Source file backend_intf.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
(** Backend interface.
The [op_*] functions mirror tinygrad's UOps. A backend can execute them
eagerly, raise effects for a JIT, build a computation graph, etc.
The frontend handles broadcasting and shape validation, so each operation
simply produces a fresh tensor; the one exception is [op_assign], which
returns [unit]. *)
module type S = sig
type ('a, 'b) t
(** Opaque tensor handle.
    ['a] is the OCaml element type; ['b] tags the dtype. *)
type context
(** Backend execution context. Carries any state required by the
    implementation (memory pools, command queues, ...). *)
val view : ('a, 'b) t -> View.t
(** [view t] returns the logical view metadata of [t]. *)
val dtype : ('a, 'b) t -> ('a, 'b) Dtype.t
(** [dtype t] is the element type of [t]. *)
val context : ('a, 'b) t -> context
(** [context t] is the execution context of [t]. *)
val data : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t
(** [data t] returns the raw buffer of [t] as a one-dimensional, C-layout
    Bigarray. *)
val op_buffer :
context -> ('a, 'b) Dtype.t -> int -> ('a, 'b) t
(** [op_buffer ctx dtype size_in_elements] allocates a buffer of
    [size_in_elements] elements of [dtype].
    NOTE(review): the initial contents of the buffer are not specified
    here. *)
val op_const_scalar : context -> 'a -> ('a, 'b) Dtype.t -> ('a, 'b) t
(** [op_const_scalar ctx value dtype] is a tensor containing the single
    scalar [value]. *)
val op_const_array :
context -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
(** [op_const_array ctx array] is a tensor containing the elements of
    [array]. The array must be contiguous. *)
val op_add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_add a b] is the element-wise addition of [a] and [b]. *)
val op_mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_mul a b] is the element-wise multiplication of [a] and [b]. *)
val op_idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_idiv a b] is the integer division of [a] by [b], truncating. *)
val op_fdiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_fdiv a b] is the floating-point division of [a] by [b]. *)
val op_max : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_max a b] is the element-wise maximum of [a] and [b]. *)
val op_mod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_mod a b] is the integer modulus of [a] by [b]. *)
val op_pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_pow base exponent] raises [base] to [exponent].
    NOTE(review): the arguments are unlabeled; base-first order is assumed
    from the wording of this comment. Confirm against an implementation. *)
val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
(** [op_cmplt a b] compares [a] and [b] with [<]. Returns 0 or 1 as
    uint8. *)
val op_cmpne : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
(** [op_cmpne a b] compares [a] and [b] with [<>]. Returns 0 or 1 as
    uint8. *)
val op_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_xor a b] is the bitwise XOR of [a] and [b]. *)
val op_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_or a b] is the bitwise OR of [a] and [b]. *)
val op_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_and a b] is the bitwise AND of [a] and [b]. *)
val op_neg : ('a, 'b) t -> ('a, 'b) t
(** [op_neg t] is the negation of [t] (logical not for bools). *)
val op_log2 : ('a, 'b) t -> ('a, 'b) t
(** [op_log2 t] is the base-2 logarithm of [t]. *)
val op_exp2 : ('a, 'b) t -> ('a, 'b) t
(** [op_exp2 t] is the base-2 exponential of [t]. *)
val op_sin : ('a, 'b) t -> ('a, 'b) t
(** [op_sin t] is the sine of [t]. *)
val op_sqrt : ('a, 'b) t -> ('a, 'b) t
(** [op_sqrt t] is the square root of [t]. *)
val op_recip : ('a, 'b) t -> ('a, 'b) t
(** [op_recip t] is the reciprocal of [t]. *)
val op_where :
(int, Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_where cond if_true if_false] selects from [if_true] or [if_false]
    based on the boolean tensor [cond], which is encoded as uint8.
    NOTE(review): the two value tensors share a type, so the
    [if_true]-then-[if_false] order is assumed from the wording of this
    comment. Confirm against an implementation. *)
val op_reduce_sum :
axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
(** [op_reduce_sum ~axes ~keepdims t] sums [t] over [axes]. Keeps reduced
    dimensions if [keepdims] is true. *)
val op_reduce_max :
axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
(** [op_reduce_max ~axes ~keepdims t] takes the maximum of [t] over [axes].
    Keeps reduced dimensions if [keepdims] is true. *)
val op_reduce_prod :
axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
(** [op_reduce_prod ~axes ~keepdims t] takes the product of [t] over
    [axes]. Keeps reduced dimensions if [keepdims] is true. *)
val op_expand : ('a, 'b) t -> int array -> ('a, 'b) t
(** [op_expand t new_shape] broadcasts dimensions of size 1 to
    [new_shape]. *)
val op_reshape : ('a, 'b) t -> int array -> ('a, 'b) t
(** [op_reshape t new_shape] changes the logical shape to [new_shape]
    without moving data. *)
val op_permute : ('a, 'b) t -> int array -> ('a, 'b) t
(** [op_permute t axes] reorders the dimensions of [t] according to
    [axes]. *)
val op_pad : ('a, 'b) t -> (int * int) array -> 'a -> ('a, 'b) t
(** [op_pad t padding fill_value] pads [t] with [fill_value] using the
    configuration [padding], an array of integer pairs.
    NOTE(review): presumably one (before, after) pair per dimension;
    confirm. *)
val op_shrink : ('a, 'b) t -> (int * int) array -> ('a, 'b) t
(** [op_shrink t bounds] slices [t] according to the start/stop pairs in
    [bounds].
    NOTE(review): whether the stop bound is exclusive is not stated here;
    confirm. *)
val op_flip : ('a, 'b) t -> bool array -> ('a, 'b) t
(** [op_flip t flags] flips the dimensions of [t] for which [flags] is
    [true]. *)
val op_cat : ('a, 'b) t list -> int -> ('a, 'b) t
(** [op_cat ts axis] concatenates the tensors [ts] along [axis]. *)
val op_cast : ('a, 'b) t -> ('c, 'd) Dtype.t -> ('c, 'd) t
(** [op_cast t target_dtype] casts the elements of [t] to
    [target_dtype]. *)
val op_contiguous : ('a, 'b) t -> ('a, 'b) t
(** [op_contiguous t] returns a C-contiguous tensor. May copy. *)
val op_copy : ('a, 'b) t -> ('a, 'b) t
(** [op_copy t] duplicates [t]. The result has its own buffer. *)
val op_assign : ('a, 'b) t -> ('a, 'b) t -> unit
(** Store [src] into [dst]. Returns [unit], so the effect is an in-place
    update rather than a new tensor.
    NOTE(review): the two arguments are unlabeled and share a type, so
    this signature does not show which one is [dst] and which is [src];
    confirm against an implementation. An earlier revision of this comment
    mentioned "the given logical indices", but no index argument exists. *)
val op_threefry :
(int32, Dtype.int32_elt) t ->
(int32, Dtype.int32_elt) t ->
(int32, Dtype.int32_elt) t
(** Threefry random number generator. Takes two int32 tensors and returns
    an int32 tensor.
    NOTE(review): the role of each input (presumably key and counter) is
    not documented here; confirm. *)
val op_gather :
('a, 'b) t ->
(int32, Dtype.int32_elt) t ->
int ->
('a, 'b) t
(** [op_gather data indices axis] gathers elements from [data] along [axis]
    using [indices]. Output shape matches [indices]. Ranks of [data] and
    [indices] must match. Sizes of [indices] dims != [axis] must be <=
    [data] corresponding dims. *)
val op_scatter :
?mode:[ `Set | `Add ] ->
?unique_indices:bool ->
('a, 'b) t ->
(int32, Dtype.int32_elt) t ->
('a, 'b) t ->
int ->
('a, 'b) t
(** [op_scatter ?mode ?unique_indices data_template indices updates axis]
    scatters [updates] into a new tensor shaped like [data_template] along
    [axis] using [indices]. Returns a new tensor.
    - [mode] specifies how to handle duplicate indices:
    - [`Set] (default): last update wins
    - [`Add]: accumulate updates at duplicate indices
    - [unique_indices]: hint that indices are unique (optimization)

    NOTE(review): [data_template] and [updates] share a type, so their
    order (template before [indices], updates after) is assumed from the
    wording of this comment. Confirm against an implementation. *)
val op_unfold :
('a, 'b) t ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
(** [op_unfold t ~kernel_size ~stride ~dilation ~padding] is the unfold
    (im2col) operation. Extracts sliding local blocks from a batched input
    tensor. For an input of shape (N, C, *spatial_dims), produces output
    of shape (N, C * prod(kernel_size), L) where L is the number of blocks.
    Works for any number of spatial dimensions. *)
val op_fold :
('a, 'b) t ->
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
(** [op_fold t ~output_size ~kernel_size ~stride ~dilation ~padding] is the
    fold (col2im) operation. Combines an array of sliding local blocks into
    a tensor. For an input of shape (N, C * prod(kernel_size), L), produces
    output of shape (N, C, *output_size). Inverse of unfold. Overlapping
    values are summed. Works for any number of spatial dimensions. *)
val op_matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
(** [op_matmul a b] is the matrix multiplication of [a] and [b]. For 2D
    tensors, computes standard matrix multiplication. For higher
    dimensions, performs batched matrix multiplication on the last two
    dimensions, broadcasting batch dimensions as needed. The last dimension
    of [a] must match the second-to-last dimension of [b]. *)
end