package caisar

  1. Overview
  2. Docs

Source file tensor.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2024                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)
open Base

(** A tensor is a Bigarray genarray fixed to C layout (indices start at 0
    and the last index varies fastest). ['elt] is the OCaml element type and
    ['repr] the Bigarray element representation, e.g.
    [(float, Bigarray.float64_elt) t]. *)
type ('elt, 'repr) t = ('elt, 'repr, Bigarray.c_layout) Bigarray.Genarray.t

(** [copy t] allocates a new C-layout genarray with the same element kind and
    dimensions as [t] and blits the contents of [t] into it. The result does
    not share storage with [t], so later writes to either are independent. *)
let copy t =
  let module G = Bigarray.Genarray in
  let dst = G.create (G.kind t) Bigarray.c_layout (G.dims t) in
  G.blit t dst;
  dst

(** [of_tensor t] returns an independent copy of [t] (delegates to [copy]). *)
let of_tensor t = copy t

(** [to_tensor t] returns an independent copy of [t] (delegates to [copy]). *)
let to_tensor t = copy t

(** [create_1_float v] returns a fresh one-dimensional [float64] tensor of
    length 1 whose only element is [v]. *)
let create_1_float v =
  Bigarray.Array1.of_array Bigarray.float64 Bigarray.c_layout [| v |]
  |> Bigarray.genarray_of_array1

(** [create_1_int64 v] returns a fresh one-dimensional [int64] tensor of
    length 1 whose only element is [v]. *)
let create_1_int64 v =
  let cell = Bigarray.(Genarray.create int64 c_layout [| 1 |]) in
  (* The array has a single element, so filling it is the same as setting
     index 0. *)
  Bigarray.Genarray.fill cell v;
  cell

(** [shape x] converts the dimensions of [x] (as reported by
    [Bigarray.Genarray.dims]) with [Shape.of_array]. *)
let shape x = Bigarray.Genarray.dims x |> Shape.of_array

(** [flatten t] returns every element of [t] as a list, in the order of its
    flattened one-dimensional view (row-major for C layout). The element
    count is taken from [Shape.size (shape t)].
    NOTE(review): elements are read from index 0, which assumes a C-layout
    input, as in the type [t]. *)
let flatten t =
  let total = Shape.size (shape t) in
  let flat = Bigarray.reshape_1 t total in
  List.init (Bigarray.Array1.dim flat) ~f:(Bigarray.Array1.get flat)

(** [of_array1 shape t] builds a tensor with dimensions [shape] that holds a
    copy of the elements of the one-dimensional array [t]. The data is copied
    before reshaping, so the result does not share storage with [t].
    @raise Invalid_argument (from [Bigarray.reshape]) if [shape] is not
    compatible with the length of [t]. *)
let of_array1 shape t =
  let dims = Shape.to_array_unsafe shape in
  let fresh = Bigarray.genarray_of_array1 t |> copy in
  Bigarray.reshape fresh dims

(** [reshape shape t] views [t] with dimensions [shape]. No data is copied:
    the result shares storage with [t], so writes through either one are
    visible in both.
    @raise Invalid_argument (from [Bigarray.reshape]) if [shape] is not
    compatible with the dimensions of [t]. *)
let reshape shape t = Shape.to_array_unsafe shape |> Bigarray.reshape t

(** [get t index] reads the element of [t] at the multi-index [index].
    @raise Invalid_argument (from [Bigarray.Genarray.get]) if [index] has the
    wrong number of coordinates or is out of bounds. *)
let get t index = Bigarray.Genarray.get t index
OCaml

Innovation. Community. Security.