N-dimensional array operations for OCaml.
This module provides NumPy-style tensor operations. Tensors are immutable views over mutable buffers, supporting broadcasting, slicing, and efficient memory layout transformations.
Type System
The type ('a, 'b) t represents a tensor where 'a is the OCaml type of elements and 'b is the bigarray element type. For example, (float, float32_elt) t is a tensor of 32-bit floats.
Broadcasting
Operations automatically broadcast compatible shapes: dimensions are aligned from the right, and each pair must be equal or one of them must be 1. Shape [|3; 1; 5|] broadcasts with [|1; 4; 5|] to [|3; 4; 5|].
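The rule above can be sketched in plain OCaml. This is only an illustration using the standard library; broadcast_shapes is a hypothetical helper, not a function of this module.

```ocaml
(* Hypothetical helper, not part of the documented API: computes the
   broadcast result of two shapes, aligning dimensions from the right. *)
let broadcast_shapes (a : int array) (b : int array) : int array =
  let na = Array.length a and nb = Array.length b in
  let n = max na nb in
  Array.init n (fun i ->
      (* Missing leading dimensions behave as size 1. *)
      let da = if i < n - na then 1 else a.(i - (n - na)) in
      let db = if i < n - nb then 1 else b.(i - (n - nb)) in
      if da = db then da
      else if da = 1 then db
      else if db = 1 then da
      else invalid_arg (Printf.sprintf "cannot broadcast %d with %d" da db))

let () =
  assert (broadcast_shapes [| 3; 1; 5 |] [| 1; 4; 5 |] = [| 3; 4; 5 |]);
  assert (broadcast_shapes [| 3; 1 |] [| 2; 3; 4 |] = [| 2; 3; 4 |])
```

Incompatible pairs such as 3 and 4 raise Invalid_argument in this sketch; how the library itself reports such errors is not shown here.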
Memory Layout
Tensors can be C-contiguous or strided. Operations return views when possible (O(1)) and copy otherwise (O(n)). Use is_c_contiguous to check layout and contiguous to ensure contiguity.
Type Definitions
('a, 'b) t is a tensor with OCaml type 'a and bigarray type 'b.
type ('a, 'b) dtype = ('a, 'b) Nx_core.Dtype.t =
| Float16 : (float, float16_elt) dtype
| Float32 : (float, float32_elt) dtype
| Float64 : (float, float64_elt) dtype
| Int8 : (int, int8_elt) dtype
| UInt8 : (int, uint8_elt) dtype
| Int16 : (int, int16_elt) dtype
| UInt16 : (int, uint16_elt) dtype
| Int32 : (int32, int32_elt) dtype
| Int64 : (int64, int64_elt) dtype
| Int : (int, int_elt) dtype
| NativeInt : (nativeint, nativeint_elt) dtype
| Complex32 : (Complex.t, complex32_elt) dtype
| Complex64 : (Complex.t, complex64_elt) dtype
Data type specification. Links OCaml types to bigarray element types.
type index =
| I of int
| L of int list
| R of int list
Range [start; stop; step] where stop is exclusive.
Index specification for slicing.
Array Properties
Functions to inspect array dimensions, memory layout, and data access.
data t returns the underlying bigarray buffer.
Buffer may contain data beyond tensor bounds for strided views. Direct access requires careful index computation using strides and offset.
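As a rough illustration of that index computation, the sketch below maps a multi-dimensional index to a flat buffer position. Both helpers are hypothetical and use only the standard library. The sketch assumes strides and offset are counted in elements; since this module reports byte strides, they would first need to be divided by itemsize.

```ocaml
(* Hypothetical helper: maps a multi-dimensional index to a position in a
   flat buffer. Assumes strides and offset are both counted in elements. *)
let linear_index ~(offset : int) ~(strides : int array) (idx : int array) : int =
  if Array.length strides <> Array.length idx then
    invalid_arg "linear_index: rank mismatch";
  let pos = ref offset in
  Array.iteri (fun d i -> pos := !pos + (i * strides.(d))) idx;
  !pos

(* Hypothetical helper: row-major (C-order) element strides for a shape. *)
let c_strides (shape : int array) : int array =
  let n = Array.length shape in
  let s = Array.make n 1 in
  for d = n - 2 downto 0 do
    s.(d) <- s.(d + 1) * shape.(d + 1)
  done;
  s

let () =
  let strides = c_strides [| 2; 3 |] in
  assert (strides = [| 3; 1 |]);
  (* Element (1, 2) of a contiguous 2x3 tensor sits at flat position 5. *)
  assert (linear_index ~offset:0 ~strides [| 1; 2 |] = 5);
  (* A transposed view reuses the buffer with swapped strides. *)
  assert (linear_index ~offset:0 ~strides:[| 1; 3 |] [| 2; 1 |] = 5)
```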
val shape : ('a, 'b) t -> int array
shape t returns dimensions. Empty array for scalars.
dtype t returns data type.
val strides : ('a, 'b) t -> int array
strides t returns byte strides for each dimension.
val stride : int -> ('a, 'b) t -> int
stride i t returns byte stride for dimension i.
val dims : ('a, 'b) t -> int array
dims t is a synonym for shape.
val dim : int -> ('a, 'b) t -> int
dim i t returns size of dimension i.
ndim t returns number of dimensions.
val itemsize : ('a, 'b) t -> int
itemsize t returns bytes per element.
size t returns total number of elements.
val numel : ('a, 'b) t -> int
numel t is a synonym for size.
val nbytes : ('a, 'b) t -> int
nbytes t returns size t * itemsize t.
val offset : ('a, 'b) t -> int
offset t returns element offset in underlying buffer.
val is_c_contiguous : ('a, 'b) t -> bool
is_c_contiguous t returns true if elements are contiguous in C order.
to_bigarray t converts to bigarray.
Always returns contiguous copy with same shape. Use for interop with libraries expecting bigarrays.
# let t = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
val t : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# Bigarray.Genarray.dims (to_bigarray t) = shape t
- : bool = true
val to_array : ('a, 'b) t -> 'a array
to_array t converts to OCaml array.
Flattens tensor to 1-D array in row-major (C) order. Always copies.
# let t = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |]
val t : (int32, int32_elt) t = [[1, 2],
[3, 4]]
# to_array t
- : int32 array = [|1l; 2l; 3l; 4l|]
Array Creation
Functions to create and initialize arrays.
val create : ('a, 'b) dtype -> int array -> 'a array -> ('a, 'b) t
create dtype shape data creates tensor from array data.
Length of data must equal product of shape.
# create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
- : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
val init : ('a, 'b) dtype -> int array -> (int array -> 'a) -> ('a, 'b) t
init dtype shape f creates tensor where element at indices i has value f i.
Function f receives array of indices for each position. Useful for creating position-dependent values.
# init int32 [| 2; 3 |] (fun i -> Int32.of_int (i.(0) + i.(1)))
- : (int32, int32_elt) t = [[0, 1, 2],
[1, 2, 3]]
# init float32 [| 3; 3 |] (fun i -> if i.(0) = i.(1) then 1. else 0.)
- : (float, float32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
empty dtype shape allocates uninitialized tensor.
val full : ('a, 'b) dtype -> int array -> 'a -> ('a, 'b) t
full dtype shape value creates tensor filled with value.
# full float32 [| 2; 3 |] 3.14
- : (float, float32_elt) t = [[3.14, 3.14, 3.14],
[3.14, 3.14, 3.14]]
ones dtype shape creates tensor filled with ones.
zeros dtype shape creates tensor filled with zeros.
scalar dtype value creates scalar tensor containing value.
val empty_like : ('a, 'b) t -> ('a, 'b) t
empty_like t creates uninitialized tensor with same shape and dtype as t.
val full_like : ('a, 'b) t -> 'a -> ('a, 'b) t
full_like t value creates tensor shaped like t filled with value.
val ones_like : ('a, 'b) t -> ('a, 'b) t
ones_like t creates tensor shaped like t filled with ones.
val zeros_like : ('a, 'b) t -> ('a, 'b) t
zeros_like t creates tensor shaped like t filled with zeros.
val scalar_like : ('a, 'b) t -> 'a -> ('a, 'b) t
scalar_like t value creates scalar with same dtype as t.
val eye : ?m:int -> ?k:int -> ('a, 'b) dtype -> int -> ('a, 'b) t
eye ?m ?k dtype n creates matrix with ones on k-th diagonal.
Default m = n (square), k = 0 (main diagonal). Positive k shifts diagonal above main, negative below.
# eye int32 3
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
# eye ~k:1 int32 3
- : (int32, int32_elt) t = [[0, 1, 0],
[0, 0, 1],
[0, 0, 0]]
# eye ~m:2 ~k:(-1) int32 3
- : (int32, int32_elt) t = [[0, 0, 0],
[1, 0, 0]]
identity dtype n creates n×n identity matrix.
Equivalent to eye dtype n. Square matrix with ones on main diagonal, zeros elsewhere.
# identity int32 3
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
val arange : ('a, 'b) dtype -> int -> int -> int -> ('a, 'b) t
arange dtype start stop step generates values from start up to, but not including, stop.
Step must be non-zero. Result length is (stop - start) / step rounded up to the next integer, or 0 when the range is empty.
# arange int32 0 10 2
- : (int32, int32_elt) t = [0, 2, 4, 6, 8]
# arange int32 5 0 (-1)
- : (int32, int32_elt) t = [5, 4, 3, 2, 1]
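Assuming the half-open interval semantics described above, the element count follows from a ceiling division. The sketch below uses plain OCaml integers; arange_length and arange_values are hypothetical helpers, not part of this module.

```ocaml
(* Hypothetical helper: number of values visited in the half-open range
   [start, stop) when advancing by a non-zero step. *)
let arange_length (start : int) (stop : int) (step : int) : int =
  if step = 0 then invalid_arg "arange_length: step must be non-zero";
  (* Distance to cover, measured in the direction of travel. *)
  let span = if step > 0 then stop - start else start - stop in
  let stride = abs step in
  (* Ceiling division; an empty or reversed range yields 0. *)
  if span <= 0 then 0 else (span + stride - 1) / stride

(* Hypothetical helper: the values themselves, as a plain int array. *)
let arange_values (start : int) (stop : int) (step : int) : int array =
  Array.init (arange_length start stop step) (fun i -> start + (i * step))

let () =
  assert (arange_values 0 10 2 = [| 0; 2; 4; 6; 8 |]);
  assert (arange_values 5 0 (-1) = [| 5; 4; 3; 2; 1 |]);
  (* A step that does not divide the span still stays below stop. *)
  assert (arange_values 0 10 3 = [| 0; 3; 6; 9 |])
```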
val arange_f : (float, 'a) dtype -> float -> float -> float -> (float, 'a) t
arange_f dtype start stop step generates float values from start up to, but not including, stop.
Like arange but for floating-point ranges. Handles fractional steps. Due to floating-point precision, final value may differ slightly from expected.
# arange_f float32 0. 1. 0.2
- : (float, float32_elt) t = [0, 0.2, 0.4, 0.6, 0.8]
# arange_f float32 1. 0. (-0.25)
- : (float, float32_elt) t = [1, 0.75, 0.5, 0.25]
val linspace :
('a, 'b) dtype ->
?endpoint:bool ->
float ->
float ->
int ->
('a, 'b) t
linspace dtype ?endpoint start stop count generates count evenly spaced values from start to stop.
If endpoint is true (default), stop is included.
# linspace float32 ~endpoint:true 0. 10. 5
- : (float, float32_elt) t = [0, 2.5, 5, 7.5, 10]
# linspace float32 ~endpoint:false 0. 10. 5
- : (float, float32_elt) t = [0, 2, 4, 6, 8]
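The effect of endpoint on the spacing can be sketched with plain float arrays. The function below is a hypothetical standalone version written against the standard library only, and it reproduces the two examples above.

```ocaml
(* Hypothetical standalone sketch of evenly spaced values. With
   endpoint = true the span is split into count - 1 intervals so that
   stop is the last value; otherwise it is split into count intervals. *)
let linspace ?(endpoint = true) (start : float) (stop : float) (count : int) :
    float array =
  if count <= 0 then [||]
  else if count = 1 then [| start |]
  else
    let divisor = if endpoint then count - 1 else count in
    let step = (stop -. start) /. float_of_int divisor in
    Array.init count (fun i -> start +. (float_of_int i *. step))

let () =
  assert (linspace 0. 10. 5 = [| 0.; 2.5; 5.; 7.5; 10. |]);
  assert (linspace ~endpoint:false 0. 10. 5 = [| 0.; 2.; 4.; 6.; 8. |])
```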
val logspace :
(float, 'a) dtype ->
?endpoint:bool ->
?base:float ->
float ->
float ->
int ->
(float, 'a) t
logspace dtype ?endpoint ?base start_exp stop_exp count generates values evenly spaced on log scale.
Returns base ** x where x ranges from start_exp to stop_exp. Default base = 10.0.
# logspace float32 0. 2. 3
- : (float, float32_elt) t = [1, 10, 100]
# logspace float32 ~base:2.0 0. 3. 4
- : (float, float32_elt) t = [1, 2, 4, 8]
val geomspace :
(float, 'a) dtype ->
?endpoint:bool ->
float ->
float ->
int ->
(float, 'a) t
geomspace dtype ?endpoint start stop count generates values evenly spaced on geometric (multiplicative) scale.
# geomspace float32 1. 1000. 4
- : (float, float32_elt) t = [1, 10, 100, 1000]
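One way to picture geometric spacing is as linear spacing in log space. The sketch below is a hypothetical standalone version, not this module's implementation. It always includes the endpoint and, as a simplifying assumption, only accepts positive endpoints.

```ocaml
(* Hypothetical sketch: geometric spacing as linear spacing of logarithms.
   Simplifying assumptions: the endpoint is always included and both
   endpoints must be strictly positive. *)
let geomspace (start : float) (stop : float) (count : int) : float array =
  if start <= 0. || stop <= 0. then
    invalid_arg "geomspace: this sketch requires positive endpoints";
  if count <= 0 then [||]
  else if count = 1 then [| start |]
  else
    let log_start = log start in
    let step = (log stop -. log_start) /. float_of_int (count - 1) in
    Array.init count (fun i -> exp (log_start +. (float_of_int i *. step)))

let () =
  (* Results are compared with a relative tolerance because exp and log
     are not exact in floating point. *)
  let close a b = Float.abs (a -. b) <= 1e-9 *. Float.abs b in
  let expected = [| 1.; 10.; 100.; 1000. |] in
  Array.iteri (fun i v -> assert (close v expected.(i))) (geomspace 1. 1000. 4)
```

Consecutive values share a constant ratio, here 10.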
val meshgrid :
?indexing:[ `xy | `ij ] ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t * ('a, 'b) t
meshgrid ?indexing x y creates coordinate grids from 1D arrays.
Returns (X, Y) where X and Y are 2D arrays representing grid coordinates.
- `xy (default): Cartesian indexing, X changes along columns, Y changes along rows
- `ij: Matrix indexing, X changes along rows, Y changes along columns
# let x = linspace float32 0. 2. 3 in
let y = linspace float32 0. 1. 2 in
meshgrid x y
- : (float, float32_elt) t * (float, float32_elt) t =
([[0, 1, 2],
[0, 1, 2]], [[0, 0, 0],
[1, 1, 1]])
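The two indexing conventions can be sketched with nested OCaml arrays. This is a hypothetical standalone illustration using only the standard library; it mirrors the example above for `xy and shows how `ij swaps the roles of rows and columns.

```ocaml
(* Hypothetical sketch of coordinate grids over plain arrays.
   `xy: result has (length y) rows and (length x) columns.
   `ij: result has (length x) rows and (length y) columns. *)
let meshgrid ?(indexing = `xy) (x : 'a array) (y : 'a array) :
    'a array array * 'a array array =
  let nx = Array.length x and ny = Array.length y in
  match indexing with
  | `xy ->
      (* X repeats x on every row; Y is constant along each row. *)
      ( Array.init ny (fun _ -> Array.copy x),
        Array.init ny (fun i -> Array.make nx y.(i)) )
  | `ij ->
      (* X is constant along each row; Y repeats y on every row. *)
      ( Array.init nx (fun i -> Array.make ny x.(i)),
        Array.init nx (fun _ -> Array.copy y) )

let () =
  let gx, gy = meshgrid [| 0.; 1.; 2. |] [| 0.; 1. |] in
  assert (gx = [| [| 0.; 1.; 2. |]; [| 0.; 1.; 2. |] |]);
  assert (gy = [| [| 0.; 0.; 0. |]; [| 1.; 1.; 1. |] |])
```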
of_bigarray ba creates tensor from bigarray.
Zero-copy when bigarray is contiguous. Creates view sharing same memory. Modifications to either affect both.
# let ba = Bigarray.Array2.create Float32 C_layout 2 3 in
let t = of_bigarray (Bigarray.genarray_of_array2 ba) in
t
- : (float, float32_elt) t = [[0, 0, 0],
[0, 0, 0]]
Random Number Generation
Functions to generate arrays with random values.
val rand : ('a, 'b) dtype -> ?seed:int -> int array -> ('a, 'b) t
rand dtype ?seed shape generates uniform random values in [0, 1).
Only supports float dtypes. Same seed produces same sequence.
val randn : ('a, 'b) dtype -> ?seed:int -> int array -> ('a, 'b) t
randn dtype ?seed shape generates standard normal random values.
Mean 0, variance 1. Uses Box-Muller transform for efficiency. Only supports float dtypes. Same seed produces same sequence.
val randint :
('a, 'b) dtype ->
?seed:int ->
?high:int ->
int array ->
int ->
('a, 'b) t
randint dtype ?seed ?high shape low generates integers in [low, high).
Uniform distribution over range. Default high = 10. Note: high is exclusive (NumPy convention).
Shape Manipulation
Functions to reshape, transpose, and rearrange arrays.
val reshape : int array -> ('a, 'b) t -> ('a, 'b) t
reshape shape t returns view with new shape.
At most one dimension can be -1 (inferred from total elements). Product of dimensions must match total elements. Returns view when possible (O(1)), copies if tensor is not contiguous and cannot be viewed.
# let t = create int32 [|2; 3|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
reshape [|6|] t
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let t = create int32 [|6|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
reshape [|3; -1|] t
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
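The inference of a -1 dimension can be sketched as follows. infer_shape is a hypothetical helper over plain int arrays, shown only to make the rule concrete.

```ocaml
(* Hypothetical helper: resolves at most one -1 entry in a requested shape
   so that the product of dimensions equals the total element count. *)
let infer_shape (total : int) (shape : int array) : int array =
  let unknown =
    Array.fold_left (fun n d -> if d = -1 then n + 1 else n) 0 shape
  in
  if unknown > 1 then invalid_arg "infer_shape: at most one -1 allowed";
  (* Product of the dimensions that were given explicitly. *)
  let known =
    Array.fold_left (fun p d -> if d = -1 then p else p * d) 1 shape
  in
  if unknown = 0 then begin
    if known <> total then invalid_arg "infer_shape: size mismatch";
    shape
  end
  else begin
    if known = 0 || total mod known <> 0 then
      invalid_arg "infer_shape: cannot infer dimension";
    Array.map (fun d -> if d = -1 then total / known else d) shape
  end

let () =
  assert (infer_shape 6 [| 3; -1 |] = [| 3; 2 |]);
  assert (infer_shape 6 [| 6 |] = [| 6 |])
```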
val broadcast_to : int array -> ('a, 'b) t -> ('a, 'b) t
broadcast_to shape t broadcasts tensor to target shape.
Shapes must be broadcast-compatible: dimensions align from right, each must be equal or source must be 1. Returns view (no copy) with zero strides for broadcast dimensions.
# let t = create int32 [|1; 3|] [|1l; 2l; 3l|] in
broadcast_to [|3; 3|] t
- : (int32, int32_elt) t = [[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]
# let t = ones float32 [|3; 1|] in
shape (broadcast_to [|2; 3; 4|] t)
- : int array = [|2; 3; 4|]
val broadcasted :
?reverse:bool ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t * ('a, 'b) t
broadcasted ?reverse t1 t2 broadcasts tensors to common shape.
Returns views of both tensors broadcast to compatible shape. If reverse is true, returns (t2', t1') instead of (t1', t2'). Useful before element-wise operations.
# let t1 = ones float32 [|3;1|] in
let t2 = ones float32 [|1;5|] in
let t1', t2' = broadcasted t1 t2 in
shape t1', shape t2'
- : int array * int array = ([|3; 5|], [|3; 5|])
val expand : int array -> ('a, 'b) t -> ('a, 'b) t
expand shape t broadcasts tensor where -1 keeps original dimension.
Like broadcast_to but -1 preserves existing dimension size. Adds dimensions on left if needed.
# let t = ones float32 [|1; 4; 1|] in
shape (expand [|3; -1; 5|] t)
- : int array = [|3; 4; 5|]
# let t = ones float32 [|5; 5|] in
shape (expand [|-1; -1|] t)
- : int array = [|5; 5|]
val flatten : ?start_dim:int -> ?end_dim:int -> ('a, 'b) t -> ('a, 'b) t
flatten ?start_dim ?end_dim t collapses dimensions into single dimension.
Default start_dim = 0, end_dim = -1 (last). Negative indices count from end. Dimensions start_dim through end_dim inclusive are flattened.
# flatten (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|24|]
# flatten ~start_dim:1 ~end_dim:2 (zeros float32 [| 2; 3; 4; 5 |]) |> shape
- : int array = [|2; 12; 5|]
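The resulting shape can be worked out from the input shape alone. The sketch below is a hypothetical helper over plain int arrays that reproduces both examples above. Rejecting a rank-0 shape is a simplification of this sketch.

```ocaml
(* Hypothetical helper: shape obtained by merging dimensions start_dim
   through end_dim (inclusive) into one. Negative indices count from the
   end. A rank-0 shape is rejected in this sketch. *)
let flatten_shape ?(start_dim = 0) ?(end_dim = -1) (shape : int array) :
    int array =
  let n = Array.length shape in
  let norm d = if d < 0 then d + n else d in
  let s = norm start_dim and e = norm end_dim in
  if s < 0 || e >= n || s > e then
    invalid_arg "flatten_shape: invalid dimension range";
  let merged = ref 1 in
  for d = s to e do
    merged := !merged * shape.(d)
  done;
  Array.concat
    [ Array.sub shape 0 s; [| !merged |]; Array.sub shape (e + 1) (n - e - 1) ]

let () =
  assert (flatten_shape [| 2; 3; 4 |] = [| 24 |]);
  assert (flatten_shape ~start_dim:1 ~end_dim:2 [| 2; 3; 4; 5 |] = [| 2; 12; 5 |])
```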
val unflatten : int -> int array -> ('a, 'b) t -> ('a, 'b) t
unflatten dim sizes t expands dimension dim into multiple dimensions.
Product of sizes must equal size of dimension dim. At most one dimension can be -1 (inferred). Inverse of flatten.
# unflatten 1 [| 3; 4 |] (zeros float32 [| 2; 12; 5 |]) |> shape
- : int array = [|2; 3; 4; 5|]
# unflatten 0 [| -1; 2 |] (ones float32 [| 6; 5 |]) |> shape
- : int array = [|3; 2; 5|]
val ravel : ('a, 'b) t -> ('a, 'b) t
ravel t returns contiguous 1-D view.
Equivalent to flatten t but always returns contiguous result. Use when you need both flattening and contiguity.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
ravel x
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let t = transpose (ones float32 [| 3; 4 |]) in
is_c_contiguous t
- : bool = false
# let t_ravel = ravel t in
is_c_contiguous t_ravel
- : bool = true
val squeeze : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t
squeeze ?axes t removes dimensions of size 1.
If axes is specified, only removes those dimensions. Negative indices count from end. Returns view when possible.
# squeeze (ones float32 [| 1; 3; 1; 4 |]) |> shape
- : int array = [|3; 4|]
# squeeze ~axes:[| 0; 2 |] (ones float32 [| 1; 3; 1; 4 |]) |> shape
- : int array = [|3; 4|]
# squeeze ~axes:[| -1 |] (ones float32 [| 3; 4; 1 |]) |> shape
- : int array = [|3; 4|]
val unsqueeze : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t
unsqueeze ?axes t inserts dimensions of size 1 at specified positions.
Axes refer to positions in result tensor. Must be in range [0, ndim].
# unsqueeze ~axes:[| 0; 2 |] (create float32 [| 3 |] [| 1.; 2.; 3. |]) |> shape
- : int array = [|1; 3; 1|]
# unsqueeze ~axes:[| 1 |] (create float32 [| 2 |] [| 5.; 6. |]) |> shape
- : int array = [|2; 1|]
val squeeze_axis : int -> ('a, 'b) t -> ('a, 'b) t
squeeze_axis axis t removes dimension axis if size is 1.
val unsqueeze_axis : int -> ('a, 'b) t -> ('a, 'b) t
unsqueeze_axis axis t inserts dimension of size 1 at axis.
val expand_dims : int array -> ('a, 'b) t -> ('a, 'b) t
val transpose : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t
transpose ?axes t permutes dimensions.
Default reverses all dimensions. axes must be a permutation of 0..ndim-1. Returns view (no copy) with adjusted strides.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
transpose x
- : (int32, int32_elt) t = [[1, 4],
[2, 5],
[3, 6]]
# transpose ~axes:[| 2; 0; 1 |] (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|4; 2; 3|]
# let id = transpose ~axes:[| 1; 0 |] in
id == transpose
- : bool = false
val flip : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t
flip ?axes t reverses order along specified dimensions.
Default flips all dimensions.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
flip x
- : (int32, int32_elt) t = [[6, 5, 4],
[3, 2, 1]]
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
flip ~axes:[| 1 |] x
- : (int32, int32_elt) t = [[3, 2, 1],
[6, 5, 4]]
val moveaxis : int -> int -> ('a, 'b) t -> ('a, 'b) t
moveaxis src dst t moves dimension from src to dst.
val swapaxes : int -> int -> ('a, 'b) t -> ('a, 'b) t
swapaxes axis1 axis2 t exchanges two dimensions.
val roll : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t
roll ?axis shift t shifts elements along axis.
Elements shifted beyond last position wrap to beginning. If axis is not specified, shifts flattened tensor. Negative shift rolls backward.
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
roll 2 x
- : (int32, int32_elt) t = [4, 5, 1, 2, 3]
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
roll ~axis:1 1 x
- : (int32, int32_elt) t = [[3, 1, 2],
[6, 4, 5]]
# let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
roll ~axis:0 (-1) x
- : (int32, int32_elt) t = [[3, 4],
[1, 2]]
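The wrap-around behaviour on a flattened tensor can be sketched on a plain OCaml array. roll_array is a hypothetical helper using only the standard library.

```ocaml
(* Hypothetical helper: cyclic shift of a 1-D array. A positive shift moves
   elements toward higher indices; a negative shift moves them backward. *)
let roll_array (shift : int) (a : 'a array) : 'a array =
  let n = Array.length a in
  if n = 0 then a
  else
    (* Normalize the shift into 0 .. n-1, also for negative values. *)
    let k = ((shift mod n) + n) mod n in
    Array.init n (fun i -> a.((i - k + n) mod n))

let () =
  assert (roll_array 2 [| 1; 2; 3; 4; 5 |] = [| 4; 5; 1; 2; 3 |]);
  assert (roll_array (-1) [| 1; 2; 3; 4 |] = [| 2; 3; 4; 1 |])
```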
val pad : (int * int) array -> 'a -> ('a, 'b) t -> ('a, 'b) t
pad padding value t pads tensor with value.
padding specifies (before, after) for each dimension. Length must match tensor dimensions. Negative padding not allowed.
# let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
pad [| (1, 1); (2, 2) |] 0l x
- : (int32, int32_elt) t =
[[0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 0, 0],
[0, 0, 3, 4, 0, 0],
[0, 0, 0, 0, 0, 0]]
val shrink : (int * int) array -> ('a, 'b) t -> ('a, 'b) t
shrink ranges t extracts slice from start to stop (exclusive) for each dimension.
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
shrink [| (1, 3); (0, 2) |] x
- : (int32, int32_elt) t = [[4, 5],
[7, 8]]
val tile : int array -> ('a, 'b) t -> ('a, 'b) t
tile reps t constructs tensor by repeating t.
reps specifies repetitions per dimension. If longer than ndim, prepends dimensions. Zero repetitions create empty tensor.
# let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
tile [| 2; 3 |] x
- : (int32, int32_elt) t = [[1, 2, 1, 2, 1, 2],
[1, 2, 1, 2, 1, 2]]
# let x = create int32 [| 2 |] [| 1l; 2l |] in
tile [| 2; 1; 3 |] x |> shape
- : int array = [|2; 1; 6|]
val repeat : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t
repeat ?axis count t repeats elements count times.
If axis is not specified, repeats flattened tensor.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
repeat 2 x
- : (int32, int32_elt) t = [1, 1, 2, 2, 3, 3]
# let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
repeat ~axis:0 3 x
- : (int32, int32_elt) t = [[1, 2],
[1, 2],
[1, 2]]
Array Combination and Splitting
Functions to join and split arrays.
val concatenate : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t
concatenate ?axis ts joins tensors along existing axis.
All tensors must have same shape except on concatenation axis. If axis is not specified, flattens all tensors then concatenates. Returns contiguous result.
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
concatenate ~axis:0 [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
concatenate [x1; x2]
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
val stack : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t
stack ?axis ts joins tensors along new axis.
All tensors must have identical shape. Result rank is input rank + 1. Default axis=0. Negative axis counts from end of result shape.
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
stack [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4]]
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
stack ~axis:1 [x1; x2]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
# stack ~axis:(-1) [ones float32 [| 2; 3 |]; zeros float32 [| 2; 3 |]] |> shape
- : int array = [|2; 3; 2|]
val vstack : ('a, 'b) t list -> ('a, 'b) t
vstack ts stacks tensors vertically (row-wise).
1-D tensors are treated as row vectors (shape [|1; n|]). Higher-D tensors concatenate along axis 0. All tensors must have same shape except possibly first dimension.
# let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
vstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# let x1 = create int32 [| 1; 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2; 2 |] [| 3l; 4l; 5l; 6l |] in
vstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
val hstack : ('a, 'b) t list -> ('a, 'b) t
hstack ts stacks tensors horizontally (column-wise).
1-D tensors concatenate directly. Higher-D tensors concatenate along axis 1. For 1-D arrays of different lengths, use vstack to make 2-D first.
# let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let x1 = create int32 [| 2; 1 |] [| 1l; 2l |] in
let x2 = create int32 [| 2; 1 |] [| 3l; 4l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 2; 1 |] [| 5l; 6l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2, 5],
[3, 4, 6]]
val dstack : ('a, 'b) t list -> ('a, 'b) t
dstack ts stacks tensors depth-wise (along third axis).
Tensors are reshaped to at least 3-D before concatenation:
- 1-D shape [|n|] → [|1; n; 1|]
- 2-D shape [|m; n|] → [|m; n; 1|]
- 3-D+ unchanged
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
dstack [x1; x2]
- : (int32, int32_elt) t = [[[1, 3],
[2, 4]]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 2; 2 |] [| 5l; 6l; 7l; 8l |] in
dstack [x1; x2]
- : (int32, int32_elt) t = [[[1, 5],
[2, 6]],
[[3, 7],
[4, 8]]]
val broadcast_arrays : ('a, 'b) t list -> ('a, 'b) t list
broadcast_arrays ts broadcasts all tensors to common shape.
Finds the common broadcast shape and returns list of views with that shape. Broadcasting rules: dimensions align right, each must be 1 or equal.
# let x1 = ones float32 [| 3; 1 |] in
let x2 = ones float32 [| 1; 5 |] in
broadcast_arrays [x1; x2] |> List.map shape
- : int array list = [[|3; 5|]; [|3; 5|]]
# let x1 = scalar float32 5. in
let x2 = ones float32 [| 2; 3; 4 |] in
broadcast_arrays [x1; x2] |> List.map shape
- : int array list = [[|2; 3; 4|]; [|2; 3; 4|]]
val array_split :
axis:int ->
[< `Count of int | `Indices of int list ] ->
('a, 'b) t ->
('a, 'b) t list
array_split ~axis sections t splits tensor into multiple parts.
`Count n divides into n parts as evenly as possible. Extra elements go to first parts. `Indices [i1;i2;...] splits at indices creating [start:i1], [i1:i2], [i2:end].
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
array_split ~axis:0 (`Count 3) x
- : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5]]
# let x = create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
array_split ~axis:0 (`Indices [ 2; 4 ]) x
- : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5, 6]]
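The two splitting modes can be sketched as computations on lengths and boundaries. Both helpers below are hypothetical and work on plain integers; they only show how part sizes and half-open ranges follow from the rules above.

```ocaml
(* Hypothetical helper: sizes of the parts produced by splitting len
   elements into n parts, with extra elements going to the first parts. *)
let split_sizes (len : int) (n : int) : int list =
  if n <= 0 then invalid_arg "split_sizes: n must be positive";
  let base = len / n and extra = len mod n in
  List.init n (fun i -> if i < extra then base + 1 else base)

(* Hypothetical helper: half-open (start, stop) ranges produced by
   splitting at the given indices. *)
let split_ranges (len : int) (indices : int list) : (int * int) list =
  let rec go start = function
    | [] -> [ (start, len) ]
    | i :: rest -> (start, i) :: go i rest
  in
  go 0 indices

let () =
  (* Five elements in three parts: sizes 2, 2 and 1. *)
  assert (split_sizes 5 3 = [ 2; 2; 1 ]);
  (* Six elements split at 2 and 4: three ranges of two elements. *)
  assert (split_ranges 6 [ 2; 4 ] = [ (0, 2); (2, 4); (4, 6) ])
```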
val split : axis:int -> int -> ('a, 'b) t -> ('a, 'b) t list
split ~axis sections t splits into equal parts.
# let x = create int32 [| 4; 2 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
split ~axis:0 2 x
- : (int32, int32_elt) t list = [[[1, 2],
[3, 4]]; [[5, 6],
[7, 8]]]
Type Conversion and Copying
Functions to convert between types and create copies.
cast dtype t converts elements to new dtype.
Returns copy with same values in new type.
# let x = create float32 [| 3 |] [| 1.5; 2.7; 3.1 |] in
cast int32 x
- : (int32, int32_elt) t = [1, 2, 3]
astype dtype t is a synonym for cast.
val contiguous : ('a, 'b) t -> ('a, 'b) t
contiguous t returns C-contiguous tensor.
Returns t unchanged if already contiguous (O(1)), otherwise creates contiguous copy (O(n)). Use before operations requiring direct memory access.
# let t = transpose (ones float32 [| 3; 4 |]) in
is_c_contiguous (contiguous t)
- : bool = true
val copy : ('a, 'b) t -> ('a, 'b) t
copy t returns deep copy.
Always allocates new memory and copies data. Result is contiguous.
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
let y = copy x in
set_item [ 0 ] 999. y;
x, y
- : (float, float32_elt) t * (float, float32_elt) t =
([1, 2, 3], [999, 2, 3])
val blit : ('a, 'b) t -> ('a, 'b) t -> unit
blit src dst copies src into dst.
Shapes must match exactly. Handles broadcasting internally. Modifies dst in-place.
let dst = zeros float32 [| 3; 3 |] in
blit (ones float32 [| 3; 3 |]) dst
(* dst now contains all 1s *)
val fill : 'a -> ('a, 'b) t -> ('a, 'b) t
fill value t sets all elements to value.
Modifies t in-place and returns it for chaining.
# let x = zeros float32 [| 2; 3 |] in
let y = fill 5. x in
y == x
- : bool = true
Element Access and Slicing
Functions to access and modify array elements.
slice indices t extracts subtensor.
- I n: select index n (reduces dimension)
- L [i;j;k]: fancy indexing, select indices i, j, k
- R [start;stop;step]: range [start, stop) with step
Stop is exclusive. Negative indices count from end. Missing indices select all. Returns view when possible.
# let x = create int32 [| 2; 4 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
slice [ I 1 ] x
- : (int32, int32_elt) t = [5, 6, 7, 8]
# let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
slice [ R [ 1; 3 ] ] x
- : (int32, int32_elt) t = [1, 2]
val set_slice : index list -> ('a, 'b) t -> ('a, 'b) t -> unit
set_slice indices t value assigns value to slice.
val slice_ranges :
?steps:int list ->
int list ->
int list ->
('a, 'b) t ->
('a, 'b) t
slice_ranges ?steps starts stops t extracts ranges.
Equivalent to slice [R[s0;e0;st0]; R[s1;e1;st1]; ...] t. Lists must have same length ≤ ndim. Default step is 1. Missing dimensions select all.
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
slice_ranges [ 0; 1 ] [ 2; 3 ] x
- : (int32, int32_elt) t = [[2, 3],
[5, 6]]
# slice_ranges ~steps:[ 2; 1 ] [ 0; 0 ] [ 4; 2 ] (eye int32 4)
- : (int32, int32_elt) t = [[1, 0],
[0, 0]]
val set_slice_ranges :
?steps:int list ->
int list ->
int list ->
('a, 'b) t ->
('a, 'b) t ->
unit
set_slice_ranges ?steps starts stops t value assigns to ranges.
Like slice_ranges but assigns value to selected region. Value is broadcast to target shape if needed.
# let x = zeros float32 [| 3; 3 |] in
set_slice_ranges [ 1; 2 ] [ 2; 3 ] x (ones float32 [| 1; 1 |]);
get_item [ 1; 2 ] x
- : float = 1.
val get : int list -> ('a, 'b) t -> ('a, 'b) t
get indices t returns subtensor at indices.
Indexes from outermost dimension. Returns scalar tensor if all dimensions indexed, otherwise returns view of remaining dimensions.
# let x = create int32 [| 2; 2; 2 |] [| 0l; 1l; 2l; 3l; 4l; 5l; 6l; 7l |] in
get [ 1; 1; 1 ] x
- : (int32, int32_elt) t = 7
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
get [ 1 ] x
- : (int32, int32_elt) t = [4, 5, 6]
val set : int list -> ('a, 'b) t -> ('a, 'b) t -> unit
set indices t value assigns value at indices.
val get_item : int list -> ('a, 'b) t -> 'a
get_item indices t returns scalar value at indices.
Must provide indices for all dimensions.
val set_item : int list -> 'a -> ('a, 'b) t -> unit
set_item indices value t sets scalar value at indices.
Must provide indices for all dimensions. Modifies tensor in-place.
Basic Arithmetic Operations
Element-wise arithmetic operations and their variants.
val add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
add t1 t2 computes element-wise sum with broadcasting.
val add_s : ('a, 'b) t -> 'a -> ('a, 'b) t
add_s t scalar adds scalar to each element.
val iadd : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
iadd target value adds value to target in-place.
Returns modified target.
val radd_s : 'a -> ('a, 'b) t -> ('a, 'b) t
radd_s scalar t is add_s t scalar.
val iadd_s : ('a, 'b) t -> 'a -> ('a, 'b) t
iadd_s t scalar adds scalar to t in-place.
val sub : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
sub t1 t2 computes element-wise difference with broadcasting.
val sub_s : ('a, 'b) t -> 'a -> ('a, 'b) t
sub_s t scalar subtracts scalar from each element.
val rsub_s : 'a -> ('a, 'b) t -> ('a, 'b) t
rsub_s scalar t computes scalar - t.
val isub : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
isub target value subtracts value from target in-place.
val isub_s : ('a, 'b) t -> 'a -> ('a, 'b) t
isub_s t scalar subtracts scalar from t in-place.
val mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
mul t1 t2 computes element-wise product with broadcasting.
val mul_s : ('a, 'b) t -> 'a -> ('a, 'b) t
mul_s t scalar multiplies each element by scalar.
val rmul_s : 'a -> ('a, 'b) t -> ('a, 'b) t
rmul_s scalar t is mul_s t scalar.
val imul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
imul target value multiplies target by value in-place.
val imul_s : ('a, 'b) t -> 'a -> ('a, 'b) t
imul_s t scalar multiplies t by scalar in-place.
val div : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
div t1 t2 computes element-wise division.
True division for floats (result is float). Integer division for integers (truncates toward zero). Complex division follows standard rules.
# let x = create float32 [| 3 |] [| 7.; 8.; 9. |] in
let y = create float32 [| 3 |] [| 2.; 2.; 2. |] in
div x y
- : (float, float32_elt) t = [3.5, 4, 4.5]
# let x = create int32 [| 3 |] [| 7l; 8l; 9l |] in
let y = create int32 [| 3 |] [| 2l; 2l; 2l |] in
div x y
- : (int32, int32_elt) t = [3, 4, 4]
# let x = create int32 [| 2 |] [| -7l; 8l |] in
let y = create int32 [| 2 |] [| 2l; 2l |] in
div x y
- : (int32, int32_elt) t = [-3, 4]
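Truncation toward zero is also what OCaml's built-in integer division does, which makes the convention easy to check outside this module. The sketch below contrasts it with floor division; floor_div is a hypothetical helper included only for comparison.

```ocaml
(* OCaml's native and Int32 integer division both truncate toward zero,
   matching the integer behaviour described above. *)
let () =
  assert (7 / 2 = 3);
  assert ((-7) / 2 = -3);
  assert (Int32.div (-7l) 2l = -3l)

(* Hypothetical helper: floor division, shown for contrast. It differs
   from truncation only when the operands have opposite signs and the
   division is inexact. *)
let floor_div (a : int) (b : int) : int =
  let q = a / b in
  if a mod b <> 0 && (a < 0) <> (b < 0) then q - 1 else q

let () =
  assert (floor_div (-7) 2 = -4);
  assert (floor_div 7 2 = 3)
```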
val div_s : ('a, 'b) t -> 'a -> ('a, 'b) t
div_s t scalar divides each element by scalar.
val rdiv_s : 'a -> ('a, 'b) t -> ('a, 'b) t
rdiv_s scalar t computes scalar / t.
val idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
idiv target value divides target by value in-place.
val idiv_s : ('a, 'b) t -> 'a -> ('a, 'b) t
idiv_s t scalar divides t by scalar in-place.
val pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
pow base exponent computes element-wise power.
val pow_s : ('a, 'b) t -> 'a -> ('a, 'b) t
pow_s t scalar raises each element to scalar power.
val rpow_s : 'a -> ('a, 'b) t -> ('a, 'b) t
rpow_s scalar t computes scalar ** t.
val ipow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
ipow target exponent raises target to exponent in-place.
val ipow_s : ('a, 'b) t -> 'a -> ('a, 'b) t
ipow_s t scalar raises t to scalar power in-place.
val mod_ : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
mod_ t1 t2 computes element-wise modulo.
val mod_s : ('a, 'b) t -> 'a -> ('a, 'b) t
mod_s t scalar computes modulo scalar for each element.
val rmod_s : 'a -> ('a, 'b) t -> ('a, 'b) t
rmod_s scalar t computes scalar mod t.
val imod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
imod target divisor computes modulo in-place.
val imod_s : ('a, 'b) t -> 'a -> ('a, 'b) t
imod_s t scalar computes modulo scalar in-place.
val neg : ('a, 'b) t -> ('a, 'b) t
neg t negates all elements.
Mathematical Functions
Unary mathematical operations and special functions.
val abs : ('a, 'b) t -> ('a, 'b) t
abs t computes absolute value.
val sign : ('a, 'b) t -> ('a, 'b) t
sign t returns -1, 0, or 1 based on sign.
For unsigned types, returns 1 for all non-zero values, 0 for zero.
# let x = create float32 [| 3 |] [| -2.; 0.; 3.5 |] in
sign x
- : (float, float32_elt) t = [-1, 0, 1]
val square : ('a, 'b) t -> ('a, 'b) t
square t computes element-wise square.
val sqrt : ('a, 'b) t -> ('a, 'b) t
sqrt t computes element-wise square root.
val rsqrt : ('a, 'b) t -> ('a, 'b) t
rsqrt t computes reciprocal square root.
val recip : ('a, 'b) t -> ('a, 'b) t
recip t computes element-wise reciprocal.
val log : (float, 'a) t -> (float, 'a) t
log t computes natural logarithm.
val log2 : ('a, 'b) t -> ('a, 'b) t
log2 t computes base-2 logarithm.
val exp : (float, 'a) t -> (float, 'a) t
exp t computes exponential.
val exp2 : ('a, 'b) t -> ('a, 'b) t
val sin : ('a, 'b) t -> ('a, 'b) t
val cos : (float, 'a) t -> (float, 'a) t
val tan : (float, 'a) t -> (float, 'a) t
val asin : (float, 'a) t -> (float, 'a) t
val acos : (float, 'a) t -> (float, 'a) t
acos t computes arccosine.
val atan : (float, 'a) t -> (float, 'a) t
atan t computes arctangent.
val atan2 : (float, 'a) t -> (float, 'a) t -> (float, 'a) t
atan2 y x computes arctangent of y/x using signs to determine quadrant.
Returns angle in radians in range [-π, π]. Handles x=0 correctly.
# let y = scalar float32 1. in
let x = scalar float32 1. in
atan2 y x |> get_item [] |> Float.round
- : float = 1.
# let y = scalar float32 1. in
let x = scalar float32 0. in
atan2 y x |> get_item [] |> Float.round
- : float = 2.
# let y = scalar float32 0. in
let x = scalar float32 0. in
atan2 y x |> get_item []
- : float = 0.
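The quadrant handling can be illustrated on plain floats. This is a scalar sketch using the standard library's Float.atan2, on the assumption that the tensor function agrees with it element by element; the binding names are illustrative.

```ocaml
(* Scalar illustration using the standard library's Float.atan2. The
   tensor-level atan2 is assumed to agree with it element by element. *)
let angles =
  [ Float.atan2 1. 1.;        (* quadrant I:    pi/4  *)
    Float.atan2 1. (-1.);     (* quadrant II:   3pi/4 *)
    Float.atan2 (-1.) (-1.);  (* quadrant III: -3pi/4 *)
    Float.atan2 (-1.) 1. ]    (* quadrant IV:  -pi/4  *)

(* The plain ratio y/x cannot separate quadrant I from quadrant III. *)
let ratio_is_ambiguous =
  Float.atan (1. /. 1.) = Float.atan ((-1.) /. (-1.))

(* x = 0 needs no division: the result is pi/2 for positive y. *)
let straight_up = Float.atan2 1. 0.
```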
Sourceval sinh : (float, 'a) t -> (float, 'a) t
sinh t
computes hyperbolic sine.
Sourceval cosh : (float, 'a) t -> (float, 'a) t
cosh t
computes hyperbolic cosine.
Sourceval tanh : (float, 'a) t -> (float, 'a) t
tanh t
computes hyperbolic tangent.
Sourceval asinh : (float, 'a) t -> (float, 'a) t
asinh t
computes inverse hyperbolic sine.
Sourceval acosh : (float, 'a) t -> (float, 'a) t
acosh t
computes inverse hyperbolic cosine.
Sourceval atanh : (float, 'a) t -> (float, 'a) t
atanh t
computes inverse hyperbolic tangent.
Sourceval hypot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
hypot x y
computes sqrt(x² + y²) avoiding overflow.
Uses numerically stable algorithm: max * sqrt(1 + (min/max)²).
# let x = scalar float32 3. in
let y = scalar float32 4. in
hypot x y |> get_item []
- : float = 5.
# let x = scalar float64 1e200 in
let y = scalar float64 1e200 in
hypot x y |> get_item [] < Float.infinity
- : bool = true
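A scalar sketch of the scaled formula quoted above, next to the naive one. stable_hypot and naive_hypot are hypothetical names; whether the library's kernel has exactly this shape is an assumption.

```ocaml
(* Scaled form: max * sqrt(1 + (min/max)^2). The ratio is at most 1,
   so nothing is squared at full magnitude. *)
let stable_hypot x y =
  let ax = Float.abs x and ay = Float.abs y in
  let hi = Float.max ax ay and lo = Float.min ax ay in
  if hi = 0. then 0.
  else
    let r = lo /. hi in
    hi *. sqrt (1. +. (r *. r))

(* Naive form: x *. x overflows to infinity once |x| exceeds about 1e154. *)
let naive_hypot x y = sqrt ((x *. x) +. (y *. y))
```

With x = y = 1e200 the naive form returns infinity, while the scaled form returns a finite value close to 1.414e200.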
Sourceval trunc : ('a, 'b) t -> ('a, 'b) t
trunc t
rounds toward zero.
Removes fractional part. Positive values round down, negative round up.
# let x = create float32 [| 3 |] [| 2.7; -2.7; 2.0 |] in
trunc x
- : (float, float32_elt) t = [2, -2, 2]
Sourceval ceil : (float, 'a) t -> (float, 'a) t
ceil t
rounds up to nearest integer.
Smallest integer not less than input.
# let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
ceil x
- : (float, float32_elt) t = [3, 3, -2, -2]
Sourceval floor : (float, 'a) t -> (float, 'a) t
floor t
rounds down to nearest integer.
Largest integer not greater than input.
# let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
floor x
- : (float, float32_elt) t = [2, 2, -3, -3]
Sourceval round : (float, 'a) t -> (float, 'a) t
round t
rounds to nearest integer (half away from zero).
Ties round away from zero (not banker's rounding).
# let x = create float32 [| 4 |] [| 2.5; 3.5; -2.5; -3.5 |] in
round x
- : (float, float32_elt) t = [3, 4, -3, -4]
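For comparison, the standard library's Float.round applies the same tie rule to scalars. round_half_even is a hypothetical helper, shown only to illustrate what banker's rounding would return instead.

```ocaml
(* Float.round rounds ties away from zero, like the tensor-level round. *)
let half_away = List.map Float.round [ 2.5; 3.5; -2.5; -3.5 ]

(* Hypothetical banker's rounding (ties go to the even neighbour). *)
let round_half_even x =
  let r = Float.round x in
  let is_tie = Float.abs (x -. Float.trunc x) = 0.5 in
  if is_tie && Float.rem r 2. <> 0. then r -. Float.copy_sign 1. x else r

let half_even = List.map round_half_even [ 2.5; 3.5; -2.5; -3.5 ]
```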
Sourceval lerp : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
lerp start end_ weight
computes linear interpolation.
Returns start + weight * (end_ - start). Weight is typically in [0, 1].
# let start = scalar float32 0. in
let end_ = scalar float32 10. in
let weight = scalar float32 0.3 in
lerp start end_ weight |> get_item []
- : float = 3.
# let start = create float32 [| 2 |] [| 1.; 2. |] in
let end_ = create float32 [| 2 |] [| 5.; 8. |] in
let weight = create float32 [| 2 |] [| 0.25; 0.5 |] in
lerp start end_ weight
- : (float, float32_elt) t = [2, 5]
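The formula is easy to check on plain floats. The sketch below uses illustrative names (lerp_scalar, lerp_arrays) and ordinary OCaml arrays in place of tensors.

```ocaml
(* start + weight * (end_ - start), on a single float. *)
let lerp_scalar start end_ weight = start +. (weight *. (end_ -. start))

(* Element-wise version over equally sized float arrays. *)
let lerp_arrays start end_ weight =
  Array.init (Array.length start) (fun i ->
      lerp_scalar start.(i) end_.(i) weight.(i))
```

Weights outside [0, 1] are not rejected by this formula; they extrapolate beyond the two endpoints.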
Sourceval lerp_scalar_weight : ('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t
lerp_scalar_weight start end_ weight
interpolates with scalar weight.
Comparison and Logical Operations
Element-wise comparisons and logical operations.
cmplt t1 t2
returns 1 where t1 < t2, 0 elsewhere.
less t1 t2
is synonym for cmplt.
cmpne t1 t2
returns 1 where t1 ≠ t2, 0 elsewhere.
not_equal t1 t2
is synonym for cmpne.
cmpeq t1 t2
returns 1 where t1 = t2, 0 elsewhere.
equal t1 t2
is synonym for cmpeq.
cmpgt t1 t2
returns 1 where t1 > t2, 0 elsewhere.
greater t1 t2
is synonym for cmpgt.
cmple t1 t2
returns 1 where t1 ≤ t2, 0 elsewhere.
less_equal t1 t2
is synonym for cmple.
cmpge t1 t2
returns 1 where t1 ≥ t2, 0 elsewhere.
greater_equal t1 t2
is synonym for cmpge.
array_equal t1 t2
returns scalar 1 if all elements equal, 0 otherwise.
Broadcasts inputs before comparison. Returns 0 if shapes incompatible.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let y = create int32 [| 3 |] [| 1l; 2l; 3l |] in
array_equal x y |> get_item []
- : int = 1
# let x = create int32 [| 2 |] [| 1l; 2l |] in
let y = create int32 [| 2 |] [| 1l; 3l |] in
array_equal x y |> get_item []
- : int = 0
Sourceval maximum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
maximum t1 t2
returns element-wise maximum.
Sourceval maximum_s : ('a, 'b) t -> 'a -> ('a, 'b) t
maximum_s t scalar
returns maximum of each element and scalar.
Sourceval rmaximum_s : 'a -> ('a, 'b) t -> ('a, 'b) t
rmaximum_s scalar t
is maximum_s t scalar.
Sourceval imaximum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
imaximum target value
computes maximum in-place.
Sourceval imaximum_s : ('a, 'b) t -> 'a -> ('a, 'b) t
imaximum_s t scalar
computes maximum with scalar in-place.
Sourceval minimum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
minimum t1 t2
returns element-wise minimum.
Sourceval minimum_s : ('a, 'b) t -> 'a -> ('a, 'b) t
minimum_s t scalar
returns minimum of each element and scalar.
Sourceval rminimum_s : 'a -> ('a, 'b) t -> ('a, 'b) t
rminimum_s scalar t
is minimum_s t scalar.
Sourceval iminimum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
iminimum target value
computes minimum in-place.
Sourceval iminimum_s : ('a, 'b) t -> 'a -> ('a, 'b) t
iminimum_s t scalar
computes minimum with scalar in-place.
Sourceval logical_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
logical_and t1 t2
computes element-wise AND.
Non-zero values are true.
Sourceval logical_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
logical_or t1 t2
computes element-wise OR.
Sourceval logical_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
logical_xor t1 t2
computes element-wise XOR.
Sourceval logical_not : ('a, 'b) t -> ('a, 'b) t
logical_not t
computes element-wise NOT.
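Returns 1 - x, so it is a true logical NOT only for inputs restricted to 0 and 1: zero becomes 1 and 1 becomes 0. Other non-zero values are not mapped to 0 (5 becomes -4 in the example below).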
# let x = create int32 [| 3 |] [| 0l; 1l; 5l |] in
logical_not x
- : (int32, int32_elt) t = [1, 0, -4]
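The difference between the arithmetic form and a strict boolean NOT can be seen on plain int arrays. Both helper names below are hypothetical; strict_not is one way to normalize arbitrary non-zero inputs, not a function of this module.

```ocaml
(* Arithmetic form described above: 1 - x. *)
let arithmetic_not xs = Array.map (fun v -> 1 - v) xs

(* Strict boolean form: every non-zero value is treated as true. *)
let strict_not xs = Array.map (fun v -> if v = 0 then 1 else 0) xs
```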
isinf t
returns 1 where infinite, 0 elsewhere.
Detects both positive and negative infinity. Non-float types return all 0s.
# let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.neg_infinity; Float.nan |] in
isinf x
- : (int, uint8_elt) t = [0, 1, 1, 0]
isnan t
returns 1 where NaN, 0 elsewhere.
NaN is the only value that doesn't equal itself. Non-float types return all 0s.
# let x = create float32 [| 3 |] [| 1.; Float.nan; Float.infinity |] in
isnan x
- : (int, uint8_elt) t = [0, 1, 0]
isfinite t
returns 1 where finite, 0 elsewhere.
Finite means not inf, -inf, or NaN. Non-float types return all 1s.
# let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.nan; -0. |] in
isfinite x
- : (int, uint8_elt) t = [1, 0, 0, 1]
where cond if_true if_false
selects elements based on condition.
Returns if_true
where cond
is non-zero, if_false
elsewhere. All three inputs broadcast to common shape.
# let cond = create uint8 [| 3 |] [| 1; 0; 1 |] in
let if_true = create int32 [| 3 |] [| 2l; 3l; 4l |] in
let if_false = create int32 [| 3 |] [| 5l; 6l; 7l |] in
where cond if_true if_false
- : (int32, int32_elt) t = [2, 6, 4]
# let x = create float32 [| 4 |] [| -1.; 2.; -3.; 4. |] in
where (cmpgt x (scalar float32 0.)) x (scalar float32 0.)
- : (float, float32_elt) t = [0, 2, 0, 4]
Sourceval clamp : ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t
clamp ?min ?max t
limits values to range.
Elements below min become min; elements above max become max.
Sourceval clip : ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t
clip ?min ?max t
is synonym for clamp.
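A scalar sketch of the optional bounds, assuming that an omitted bound leaves that side unclamped. clamp_scalar is a plain-OCaml illustration with a hypothetical name, not the tensor implementation.

```ocaml
(* Each bound is applied only when it is supplied. *)
let clamp_scalar ?min ?max x =
  let x = match min with Some lo when x < lo -> lo | _ -> x in
  match max with Some hi when x > hi -> hi | _ -> x

(* Clamping to [0, 6] reproduces the scalar behaviour of ReLU6. *)
let relu6_scalar x = clamp_scalar ~min:0. ~max:6. x
```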
Bitwise Operations
Bitwise operations on integer arrays.
Sourceval bitwise_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
bitwise_xor t1 t2
computes element-wise XOR.
Sourceval bitwise_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
bitwise_or t1 t2
computes element-wise OR.
Sourceval bitwise_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
bitwise_and t1 t2
computes element-wise AND.
Sourceval bitwise_not : ('a, 'b) t -> ('a, 'b) t
bitwise_not t
computes element-wise NOT.
Sourceval invert : ('a, 'b) t -> ('a, 'b) t
Sourceval lshift : ('a, 'b) t -> int -> ('a, 'b) t
lshift t shift
left-shifts elements by shift
bits.
Equivalent to multiplication by 2^shift. Overflow wraps around.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
lshift x 2
- : (int32, int32_elt) t = [4, 8, 12]
Sourceval rshift : ('a, 'b) t -> int -> ('a, 'b) t
rshift t shift
right-shifts elements by shift
bits.
Equivalent to integer division by 2^shift (rounds toward zero).
# let x = create int32 [| 3 |] [| 8l; 9l; 10l |] in
rshift x 2
- : (int32, int32_elt) t = [2, 2, 2]
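The rounding direction only matters for negative inputs. The sketch below, on plain OCaml ints, contrasts division by 2^shift (which truncates toward zero, as described above) with an arithmetic shift (which rounds toward negative infinity). The helper names are illustrative.

```ocaml
(* Integer division by 2^shift truncates toward zero. *)
let div_pow2 x shift = x / (1 lsl shift)

(* Arithmetic right shift rounds toward negative infinity. *)
let arith_shift x shift = x asr shift
```

The two agree for non-negative values (9 gives 2 either way) and differ for negative ones (-9 gives -2 by division and -3 by shifting).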
Reduction Operations
Functions that reduce array dimensions.
Sourceval sum : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
sum ?axes ?keepdims t
sums elements along specified axes.
Default sums all axes (returns scalar). If keepdims
is true, retains reduced dimensions with size 1. Negative axes count from end.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
sum x |> get_item []
- : float = 10.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
sum ~axes:[| 0 |] x
- : (float, float32_elt) t = [4, 6]
# let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
sum ~axes:[| 1 |] ~keepdims:true x
- : (float, float32_elt) t = [[3]]
# let x = create float32 [| 1; 3 |] [| 1.; 2.; 3. |] in
sum ~axes:[| -1 |] x
- : (float, float32_elt) t = [6]
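The effect of axes and keepdims on the result shape can be sketched with a small shape calculator. reduced_shape and normalize_axis are hypothetical helpers; they model only the shape rules stated above, not the reduction itself.

```ocaml
(* Negative axes count from the end: -1 is the last axis. *)
let normalize_axis ndim axis = if axis < 0 then axis + ndim else axis

(* Shape after reducing [axes]; reduced axes are dropped, or kept with
   size 1 when keepdims is true. *)
let reduced_shape ?(keepdims = false) shape axes =
  let ndim = Array.length shape in
  let axes = Array.map (normalize_axis ndim) axes in
  let is_reduced i = Array.exists (( = ) i) axes in
  Array.to_list shape
  |> List.mapi (fun i d -> (i, d))
  |> List.filter_map (fun (i, d) ->
         if not (is_reduced i) then Some d
         else if keepdims then Some 1
         else None)
  |> Array.of_list
```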
Sourceval max : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
max ?axes ?keepdims t
finds maximum along axes.
Default reduces all axes. NaN propagates (any NaN input gives NaN output).
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
max x |> get_item []
- : float = 6.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
max ~axes:[| 0 |] x
- : (float, float32_elt) t = [3, 4]
# let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
max ~axes:[| 1 |] ~keepdims:true x
- : (float, float32_elt) t = [[2]]
Sourceval min : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
min ?axes ?keepdims t
finds minimum along axes.
Default reduces all axes. NaN propagates (any NaN input gives NaN output).
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
min x |> get_item []
- : float = 1.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
min ~axes:[| 0 |] x
- : (float, float32_elt) t = [1, 2]
Sourceval prod : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
prod ?axes ?keepdims t
computes product along axes.
Default multiplies all elements. Empty axes give 1.
# let x = create int32 [| 3 |] [| 2l; 3l; 4l |] in
prod x |> get_item []
- : int32 = 24l
# let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
prod ~axes:[| 0 |] x
- : (int32, int32_elt) t = [3, 8]
Sourceval mean : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
mean ?axes ?keepdims t
computes arithmetic mean along axes.
Sum of elements divided by count. NaN propagates.
# let x = create float32 [| 4 |] [| 1.; 2.; 3.; 4. |] in
mean x |> get_item []
- : float = 2.5
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
mean ~axes:[| 1 |] x
- : (float, float32_elt) t = [2, 5]
Sourceval var :
?axes:int array ->
?keepdims:bool ->
?ddof:int ->
('a, 'b) t ->
('a, 'b) t
var ?axes ?keepdims ?ddof t
computes variance along axes.
ddof
is delta degrees of freedom. Default 0 (population variance). Use 1 for sample variance. Variance = sum((x - mean(x))²) / (N - ddof), where N is the number of elements reduced.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
var x |> get_item []
- : float = 2.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
var ~ddof:1 x |> get_item []
- : float = 2.5
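A plain-OCaml sketch of the ddof rule on a float array, reproducing the two results above. The function name is illustrative, and the two-pass formulation is an assumption about how the value could be computed rather than a description of the library's kernel.

```ocaml
(* Sum of squared deviations from the mean, divided by (N - ddof). *)
let variance ?(ddof = 0) xs =
  let n = Array.length xs in
  let mean = Array.fold_left ( +. ) 0. xs /. float_of_int n in
  let ss =
    Array.fold_left
      (fun acc x ->
        let d = x -. mean in
        acc +. (d *. d))
      0. xs
  in
  ss /. float_of_int (n - ddof)

let std ?(ddof = 0) xs = sqrt (variance ~ddof xs)
```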
Sourceval std :
?axes:int array ->
?keepdims:bool ->
?ddof:int ->
('a, 'b) t ->
('a, 'b) t
std ?axes ?keepdims ?ddof t
computes standard deviation.
Square root of variance: sqrt(var(t, ddof)). See var
for ddof meaning.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
std x |> get_item [] |> Float.round
- : float = 1.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
std ~ddof:1 x |> get_item [] |> Float.round
- : float = 2.
Sourceval all : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> (int, uint8_elt) t
all ?axes ?keepdims t
tests if all elements are true (non-zero).
Returns 1 if all elements along axes are non-zero, 0 otherwise.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
all x |> get_item []
- : int = 1
# let x = create int32 [| 3 |] [| 1l; 0l; 3l |] in
all x |> get_item []
- : int = 0
# let x = create int32 [| 2; 2 |] [| 1l; 0l; 1l; 1l |] in
all ~axes:[| 1 |] x
- : (int, uint8_elt) t = [0, 1]
Sourceval any : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> (int, uint8_elt) t
any ?axes ?keepdims t
tests if any element is true (non-zero).
Returns 1 if any element along axes is non-zero, 0 if all are zero.
# let x = create int32 [| 3 |] [| 0l; 0l; 1l |] in
any x |> get_item []
- : int = 1
# let x = create int32 [| 3 |] [| 0l; 0l; 0l |] in
any x |> get_item []
- : int = 0
# let x = create int32 [| 2; 2 |] [| 0l; 0l; 0l; 1l |] in
any ~axes:[| 1 |] x
- : (int, uint8_elt) t = [0, 1]
argmax ?axis ?keepdims t
finds indices of maximum values.
Returns index of first occurrence for ties. If axis is not specified, operates on the flattened tensor and returns a scalar.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argmax x |> get_item []
- : int32 = 4l
# let x = create int32 [| 2; 3 |] [| 1l; 5l; 3l; 2l; 4l; 6l |] in
argmax ~axis:1 x
- : (int32, int32_elt) t = [1, 2]
argmin ?axis ?keepdims t
finds indices of minimum values.
Returns index of first occurrence for ties. If axis is not specified, operates on the flattened tensor and returns a scalar.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argmin x |> get_item []
- : int32 = 1l
# let x = create int32 [| 2; 3 |] [| 5l; 2l; 3l; 1l; 4l; 0l |] in
argmin ~axis:1 x
- : (int32, int32_elt) t = [1, 2]
Sorting and Searching
Functions for sorting arrays and finding indices.
Sourceval sort :
?descending:bool ->
?axis:int ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t
sort ?descending ?axis t
sorts elements along axis.
Returns (sorted_values, indices) where indices map sorted positions to original positions. Default sorts last axis in ascending order.
Algorithm: Bitonic sort (parallel-friendly, stable)
- Pads to power of 2 with inf/-inf for correctness
- O(n log² n) comparisons, O(log² n) depth
- Stable: preserves relative order of equal elements
- First occurrence wins for duplicate values
Special values:
- NaN: sorted to end (ascending) or beginning (descending)
- inf/-inf: sorted normally
- For integers: uses max/min values for padding
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
sort x
- : (int32, int32_elt) t * (int32, int32_elt) t =
([1, 1, 3, 4, 5], [1, 3, 0, 2, 4])
# let x = create int32 [| 2; 2 |] [| 3l; 1l; 1l; 4l |] in
sort ~descending:true ~axis:0 x
- : (int32, int32_elt) t * (int32, int32_elt) t =
([[3, 4],
[1, 1]], [[0, 1],
[1, 0]])
# let x = create float32 [| 4 |] [| Float.nan; 1.; 2.; Float.nan |] in
let v, _ = sort x in
v
- : (float, float32_elt) t = [1, 2, nan, nan]
Sourceval argsort :
?descending:bool ->
?axis:int ->
('a, 'b) t ->
(int32, int32_elt) t
argsort ?descending ?axis t
returns indices that would sort tensor.
Equivalent to snd (sort ?descending ?axis t). Returns indices such that taking elements at these indices yields the sorted array.
For 1-D: result[i] is the index of the i-th smallest element (ascending order). For N-D: sorts along the specified axis independently.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argsort x
- : (int32, int32_elt) t = [1, 3, 0, 2, 4]
# let x = create int32 [| 2; 3 |] [| 3l; 1l; 4l; 2l; 5l; 0l |] in
argsort ~axis:1 x
- : (int32, int32_elt) t = [[1, 0, 2],
[2, 0, 1]]
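A 1-D sketch of the index semantics on a plain array. It uses List.stable_sort (a merge sort) from the standard library rather than the bitonic network described for sort, so it illustrates only the result, not the algorithm. argsort_1d is a hypothetical name.

```ocaml
(* Sort the positions 0..n-1 by the values they point at. A stable sort
   keeps equal values in their original order, so the first occurrence
   of a duplicate comes first. *)
let argsort_1d ?(descending = false) xs =
  let cmp i j =
    let c = compare xs.(i) xs.(j) in
    if descending then -c else c
  in
  List.init (Array.length xs) Fun.id
  |> List.stable_sort cmp
  |> Array.of_list

(* Gathering with the indices yields the sorted values. *)
let take xs idx = Array.map (fun i -> xs.(i)) idx
```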
Linear Algebra
Matrix operations and linear algebra functions.
Sourceval dot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
dot a b
computes generalized dot product.
Contracts the last axis of a with:
- 1-D b: the only axis (axis 0)
- N-D b: second-to-last axis (axis -2)
Dimension rules:
- 1-D × 1-D: inner product, returns scalar
- 2-D × 2-D: matrix multiplication
- N-D × M-D: batched contraction over all but contracted axes
Supports broadcasting on batch dimensions. Result shape is concatenation of:
- Broadcasted batch dims
- Remaining dims from a (except last)
- Remaining dims from b (except contracted axis)
# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2 |] [| 3.; 4. |] in
dot a b |> get_item []
- : float = 11.
# let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
let b = create float32 [| 2; 2 |] [| 5.; 6.; 7.; 8. |] in
dot a b
- : (float, float32_elt) t = [[19, 22],
[43, 50]]
# dot (ones float32 [| 3; 4; 5 |]) (ones float32 [| 5; 6 |]) |> shape
- : int array = [|3; 4; 6|]
# dot (ones float32 [| 2; 3; 4; 5 |]) (ones float32 [| 3; 5; 6 |]) |> shape
- : int array = [|2; 3; 4; 6|]
Sourceval matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
matmul a b
computes matrix multiplication with broadcasting.
Follows NumPy's @ operator semantics:
- 1-D × 1-D: inner product (returns scalar tensor)
- 1-D × N-D: treated as [1 × k] @ [... × k × n] → [... × n]
- N-D × 1-D: treated as [... × m × k] @ [k × 1] → [... × m]
- N-D × M-D: batched matrix multiply on last 2 dimensions
Broadcasting rules:
- All dimensions except last 2 are broadcast together
- For 1-D inputs, dimension is temporarily added then removed
- Inner dimensions must match: a.shape[-1] == b.shape[-2]
Result shape:
- Batch dims: broadcast(a.shape[:-2], b.shape[:-2])
- Matrix dims: [..., a.shape[-2], b.shape[-1]]
- 1-D adjustments applied after
# let a = create float32 [| 3 |] [| 1.; 2.; 3. |] in
let b = create float32 [| 3 |] [| 4.; 5.; 6. |] in
matmul a b |> get_item []
- : float = 32.
# let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
let b = create float32 [| 2 |] [| 5.; 6. |] in
matmul a b
- : (float, float32_elt) t = [17, 39]
# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2; 3 |] [| 3.; 4.; 5.; 6.; 7.; 8. |] in
matmul a b
- : (float, float32_elt) t = [15, 18, 21]
# matmul (ones float32 [| 10; 3; 4 |]) (ones float32 [| 10; 4; 5 |]) |> shape
- : int array = [|10; 3; 5|]
# matmul (ones float32 [| 1; 3; 4 |]) (ones float32 [| 5; 4; 2 |]) |> shape
- : int array = [|5; 3; 2|]
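The shape rules can be sketched as a small calculator over int arrays. matmul_shape and broadcast_shapes are hypothetical helpers that model only the rules listed above; they assume both operands have at least one dimension.

```ocaml
(* Broadcast two batch shapes, aligning from the right. *)
let broadcast_shapes a b =
  let la = Array.length a and lb = Array.length b in
  let n = max la lb in
  Array.init n (fun i ->
      let da = if i < n - la then 1 else a.(i - (n - la)) in
      let db = if i < n - lb then 1 else b.(i - (n - lb)) in
      if da = db || db = 1 then da
      else if da = 1 then db
      else invalid_arg "broadcast_shapes: incompatible dimensions")

let matmul_shape a b =
  let a_vec = Array.length a = 1 and b_vec = Array.length b = 1 in
  (* Promote 1-D operands: [k] -> [1; k] on the left, [k] -> [k; 1] on the right. *)
  let a2 = if a_vec then [| 1; a.(0) |] else a in
  let b2 = if b_vec then [| b.(0); 1 |] else b in
  let na = Array.length a2 and nb = Array.length b2 in
  if a2.(na - 1) <> b2.(nb - 2) then
    invalid_arg "matmul_shape: inner dimensions differ";
  let batch =
    broadcast_shapes (Array.sub a2 0 (na - 2)) (Array.sub b2 0 (nb - 2))
  in
  (* Drop the temporary dimension again for 1-D operands. *)
  let rows = if a_vec then [||] else [| a2.(na - 2) |] in
  let cols = if b_vec then [||] else [| b2.(nb - 1) |] in
  Array.concat [ batch; rows; cols ]
```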
Activation Functions
Neural network activation functions.
Sourceval relu : ('a, 'b) t -> ('a, 'b) t
relu t
applies Rectified Linear Unit: max(0, x).
# let x = create float32 [| 5 |] [| -2.; -1.; 0.; 1.; 2. |] in
relu x
- : (float, float32_elt) t = [0, 0, 0, 1, 2]
Sourceval relu6 : (float, 'a) t -> (float, 'a) t
relu6 t
applies ReLU6: min(max(0, x), 6).
Bounded ReLU used in mobile networks. Clips to the [0, 6] range.
# let x = create float32 [| 3 |] [| -1.; 3.; 8. |] in
relu6 x
- : (float, float32_elt) t = [0, 3, 6]
Sourceval sigmoid : (float, 'a) t -> (float, 'a) t
sigmoid t
applies logistic sigmoid: 1 / (1 + exp(-x)).
Output in range (0, 1). Symmetric around x=0 where sigmoid(0) = 0.5.
# sigmoid (scalar float32 0.) |> get_item []
- : float = 0.5
# sigmoid (scalar float32 10.) |> get_item [] |> Float.round
- : float = 1.
# sigmoid (scalar float32 (-10.)) |> get_item [] |> Float.round
- : float = 0.
Sourceval hard_sigmoid :
?alpha:float ->
?beta:float ->
(float, 'a) t ->
(float, 'a) t
hard_sigmoid ?alpha ?beta t
applies piecewise linear sigmoid approximation.
Default alpha = 1/6, beta = 0.5.
Sourceval softplus : (float, 'a) t -> (float, 'a) t
softplus t
applies smooth ReLU: log(1 + exp(x)).
Smooth approximation to ReLU. Always positive, differentiable everywhere.
# softplus (scalar float32 0.) |> get_item [] |> Float.round
- : float = 1.
# softplus (scalar float32 100.) |> get_item [] |> Float.round
- : float = infinity
Sourceval silu : (float, 'a) t -> (float, 'a) t
silu t
applies Sigmoid Linear Unit: x * sigmoid(x).
Also called Swish. Smooth, non-monotonic activation.
# silu (scalar float32 0.) |> get_item []
- : float = 0.
# silu (scalar float32 1.) |> get_item [] |> Float.round
- : float = 1.
# silu (scalar float32 (-1.)) |> get_item [] |> Float.round
- : float = -0.
Sourceval hard_silu : (float, 'a) t -> (float, 'a) t
hard_silu t
applies x * hard_sigmoid(x).
Piecewise linear approximation of SiLU. More efficient than SiLU.
# let x = create float32 [| 3 |] [| -3.; 0.; 3. |] in
hard_silu x
- : (float, float32_elt) t = [-0, 0, 3]
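A scalar sketch of the two hard activations. The text gives only the defaults alpha = 1/6 and beta = 0.5; the formula clip(alpha * x + beta, 0, 1) used here is an assumption (it is the common definition and is consistent with the hard_silu example above). The function names are illustrative.

```ocaml
(* Assumed definition: clip(alpha * x + beta, 0, 1). *)
let hard_sigmoid_scalar ?(alpha = 1. /. 6.) ?(beta = 0.5) x =
  Float.min 1. (Float.max 0. ((alpha *. x) +. beta))

(* x * hard_sigmoid(x), as stated for hard_silu. *)
let hard_silu_scalar x = x *. hard_sigmoid_scalar x
```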
Sourceval log_sigmoid : (float, 'a) t -> (float, 'a) t
log_sigmoid t
computes log(sigmoid(x)).
Numerically stable version of log(1/(1+exp(-x))). Always negative.
# log_sigmoid (scalar float32 0.) |> get_item [] |> Float.round
- : float = -1.
# log_sigmoid (scalar float32 100.) |> get_item [] |> Float.abs |> (fun x -> x < 0.001)
- : bool = true
Sourceval leaky_relu : ?negative_slope:float -> (float, 'a) t -> (float, 'a) t
leaky_relu ?negative_slope t
applies Leaky ReLU.
Default negative_slope = 0.01. Returns x if x > 0, else negative_slope * x.
Sourceval hard_tanh : (float, 'a) t -> (float, 'a) t
hard_tanh t
clips values to [-1, 1].
Linear in [-1, 1], saturates outside. Cheaper than tanh.
# let x = create float32 [| 5 |] [| -2.; -0.5; 0.; 0.5; 2. |] in
hard_tanh x
- : (float, float32_elt) t = [-1, -0.5, 0, 0.5, 1]
Sourceval elu : ?alpha:float -> (float, 'a) t -> (float, 'a) t
elu ?alpha t
applies Exponential Linear Unit.
Default alpha = 1.0. Returns x if x > 0, else alpha * (exp(x) - 1). Smooth for x < 0, helps with vanishing gradients.
# elu (scalar float32 1.) |> get_item []
- : float = 1.
# elu (scalar float32 0.) |> get_item []
- : float = 0.
# elu (scalar float32 (-1.)) |> get_item [] |> Float.round
- : float = -1.
Sourceval selu : (float, 'a) t -> (float, 'a) t
selu t
applies Scaled ELU with fixed alpha=1.67326, lambda=1.0507.
Self-normalizing activation. Preserves mean 0 and variance 1 in deep networks under certain conditions.
# selu (scalar float32 0.) |> get_item []
- : float = 0.
# selu (scalar float32 1.) |> get_item [] |> Float.round
- : float = 1.
Sourceval softmax : ?axes:int array -> (float, 'a) t -> (float, 'a) t
softmax ?axes t
applies softmax normalization.
Default axis -1. Computes exp(x - max) / sum(exp(x - max)) for numerical stability. Output sums to 1 along specified axes.
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
softmax x |> to_array |> Array.map Float.round
- : float array = [|0.; 0.; 1.|]
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
sum (softmax x) |> get_item []
- : float = 1.
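The max-subtraction step can be sketched on a plain float array. softmax_1d is an illustrative name; the code follows the formula quoted above for a single axis.

```ocaml
(* exp(x - max) / sum(exp(x - max)): the largest exponent is 0, so exp
   never overflows, and the shift cancels in the ratio. *)
let softmax_1d xs =
  let m = Array.fold_left Float.max Float.neg_infinity xs in
  let exps = Array.map (fun x -> exp (x -. m)) xs in
  let total = Array.fold_left ( +. ) 0. exps in
  Array.map (fun e -> e /. total) exps
```

Because only differences from the maximum enter the computation, inputs such as 1000, 1001, 1002 give the same result as 1, 2, 3, whereas exp 1000. on its own would overflow to infinity.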
Sourceval gelu_approx : (float, 'a) t -> (float, 'a) t
gelu_approx t
applies Gaussian Error Linear Unit approximation.
Smooth activation: x * Φ(x) where Φ is Gaussian CDF. This uses tanh approximation for efficiency.
# gelu_approx (scalar float32 0.) |> get_item []
- : float = 0.
# gelu_approx (scalar float32 1.) |> get_item [] |> Float.round
- : float = 1.
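A scalar sketch of the widely used tanh approximation, 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))). That the library uses exactly these constants is an assumption; gelu_tanh is a hypothetical name.

```ocaml
(* Common tanh-based approximation of x * Phi(x). *)
let gelu_tanh x =
  let c = sqrt (2. /. Float.pi) in
  0.5 *. x *. (1. +. tanh (c *. (x +. (0.044715 *. x *. x *. x))))
```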
Sourceval softsign : (float, 'a) t -> (float, 'a) t
softsign t
computes x / (|x| + 1).
Similar to tanh but computationally cheaper. Range (-1, 1).
# let x = create float32 [| 3 |] [| -10.; 0.; 10. |] in
softsign x
- : (float, float32_elt) t = [-0.909091, 0, 0.909091]
Sourceval mish : (float, 'a) t -> (float, 'a) t
mish t
applies Mish activation: x * tanh(softplus(x)).
Self-regularizing non-monotonic activation. Smoother than ReLU.
# mish (scalar float32 0.) |> get_item [] |> Float.abs |> (fun x -> x < 0.001)
- : bool = true
# mish (scalar float32 (-10.)) |> get_item [] |> Float.round
- : float = -0.
Convolution and Pooling
Neural network convolution and pooling operations.
Sourceval im2col :
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t
im2col ~kernel_size ~stride ~dilation ~padding t
extracts sliding local blocks from tensor.
Extracts patches of size kernel_size from the input tensor at the specified stride and dilation.
- kernel_size: size of sliding blocks to extract
- stride: step between consecutive blocks
- dilation: spacing between kernel elements
- padding: (before, after) padding for each spatial dimension
For a 4D input [batch; channels; height; width], produces output shape [batch; channels * kh * kw; num_patches_h; num_patches_w], where kh, kw are the kernel dimensions and the patch counts depend on stride, dilation, and padding. In the example below, the two patch axes appear flattened into a single axis of 9 = 3 × 3 patches.
# let x = arange_f float32 0. 16. 1. |> reshape [| 1; 1; 4; 4 |] in
im2col ~kernel_size:[|2; 2|] ~stride:[|1; 1|]
~dilation:[|1; 1|] ~padding:[|(0, 0); (0, 0)|] x |> shape
- : int array = [|1; 4; 9|]
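The sizes in this example can be reproduced with the usual sliding-window count. The formula below is an assumption (it has the same form as the pooling output-width formula given for avg_pool1d), and num_patches is a hypothetical helper.

```ocaml
(* Number of window positions along one spatial dimension. *)
let num_patches ~size ~kernel ~stride ~dilation ~pad:(before, after) =
  ((size + before + after - (dilation * (kernel - 1)) - 1) / stride) + 1

(* 4x4 input, 2x2 kernel, stride 1, no dilation, no padding. *)
let patches_h = num_patches ~size:4 ~kernel:2 ~stride:1 ~dilation:1 ~pad:(0, 0)
let patches_w = num_patches ~size:4 ~kernel:2 ~stride:1 ~dilation:1 ~pad:(0, 0)

(* batch; channels * kh * kw; patches_h * patches_w *)
let unfolded_shape = [| 1; 1 * 2 * 2; patches_h * patches_w |]
```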
Sourceval col2im :
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t
col2im ~output_size ~kernel_size ~stride ~dilation ~padding t
combines sliding local blocks into tensor.
This is the inverse of im2col. Accumulates values from the unfolded representation back into spatial dimensions. Overlapping regions are summed.
- output_size: target spatial dimensions [height; width]
- kernel_size: size of sliding blocks
- stride: step between consecutive blocks
- dilation: spacing between kernel elements
- padding: (before, after) padding for each spatial dimension
For input shape [batch; channels * kh * kw; num_patches_h; num_patches_w], produces output [batch; channels; height; width].
# let unfolded = create float32 [| 1; 4; 3; 3 |] (Array.init 36 Float.of_int) in
col2im ~output_size:[|4; 4|] ~kernel_size:[|2; 2|]
~stride:[|1; 1|] ~dilation:[|1; 1|]
~padding:[|(0, 0); (0, 0)|] unfolded |> shape
- : int array = [|1; 1; 4; 4|]
Sourceval correlate1d :
?groups:int ->
?stride:int ->
?padding_mode:[ `Full | `Same | `Valid ] ->
?dilation:int ->
?fillvalue:float ->
?bias:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
correlate1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w
computes 1D cross-correlation (no kernel flip).
- x: input [batch_size; channels_in; width]
- w: weights [channels_out; channels_in/groups; kernel_width]
- bias: optional per-channel bias [channels_out]
- groups: split input/output channels into groups (default 1)
- stride: step between windows (default 1)
- padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)
- dilation: spacing between kernel elements (default 1)
- fillvalue: padding value (default 0.0)
Output width depends on padding:
- `Valid: (width - dilation*(kernel-1) - 1)/stride + 1
- `Same: width/stride (rounded up)
- `Full: (width + dilation*(kernel-1) - 1)/stride + 1
# let x = create float32 [| 1; 1; 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
let w = create float32 [| 1; 1; 3 |] [| 1.; 0.; -1. |] in
correlate1d x w |> shape
- : int array = [|1; 1; 3|]
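The three output-width formulas above can be written as one helper. out_width is a hypothetical name; the sketch assumes the input is at least as wide as the dilated kernel in the `Valid case.

```ocaml
let ceil_div a b = (a + b - 1) / b

(* Output width for each padding mode, following the formulas above. *)
let out_width ~mode ~width ~kernel ~stride ~dilation =
  let span = dilation * (kernel - 1) in
  match mode with
  | `Valid -> ((width - span - 1) / stride) + 1
  | `Same -> ceil_div width stride
  | `Full -> ((width + span - 1) / stride) + 1
```

For width 5 and kernel 3 with stride and dilation 1, this gives 3 for `Valid (the shape in the example above), 5 for `Same, and 7 for `Full.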
Sourceval correlate2d :
?groups:int ->
?stride:(int * int) ->
?padding_mode:[ `Full | `Same | `Valid ] ->
?dilation:(int * int) ->
?fillvalue:float ->
?bias:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
correlate2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w
computes 2D cross-correlation (no kernel flip).
- x: input [batch; channels_in; height; width]
- w: weights [channels_out; channels_in/groups; kernel_h; kernel_w]
- bias: optional per-channel bias [channels_out]
- stride: (stride_h, stride_w) step between windows (default (1,1))
- dilation: (dilation_h, dilation_w) kernel spacing (default (1,1))
- padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)
Uses Winograd F(4,3) for 3×3 kernels with stride 1 when beneficial. For `Same` with even kernels, pads more on bottom/right (SciPy convention).
# let image = ones float32 [| 1; 1; 5; 5 |] in
let sobel_x = create float32 [| 1; 1; 3; 3 |] [| 1.; 0.; -1.; 2.; 0.; -2.; 1.; 0.; -1. |] in
correlate2d image sobel_x |> shape
- : int array = [|1; 1; 3; 3|]
Sourceval convolve1d :
?groups:int ->
?stride:int ->
?padding_mode:[< `Full | `Same | `Valid > `Valid ] ->
?dilation:int ->
?fillvalue:'a ->
?bias:('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t
convolve1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w
computes 1D convolution (flips kernel before correlation).
Same parameters as correlate1d
but flips kernel. For `Same` with even kernels, pads more on left (NumPy convention).
# let x = create float32 [| 1; 1; 3 |] [| 1.; 2.; 3. |] in
let w = create float32 [| 1; 1; 2 |] [| 4.; 5. |] in
convolve1d x w
- : (float, float32_elt) t = [[[13, 22]]]
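The relationship between the two operations can be sketched on plain float arrays for the single-channel, stride-1, `Valid case. The helper names are illustrative, and the code ignores groups, dilation, padding, and bias.

```ocaml
(* Cross-correlation: slide the kernel over the input as-is. *)
let correlate_valid x w =
  let n = Array.length x and k = Array.length w in
  Array.init (n - k + 1) (fun i ->
      let acc = ref 0. in
      for j = 0 to k - 1 do
        acc := !acc +. (x.(i + j) *. w.(j))
      done;
      !acc)

(* Convolution: reverse the kernel, then correlate. *)
let convolve_valid x w =
  let k = Array.length w in
  let flipped = Array.init k (fun j -> w.(k - 1 - j)) in
  correlate_valid x flipped
```

With x = [1; 2; 3] and w = [4; 5], convolution gives [13; 22] as in the example above, while correlation without the flip gives [14; 23].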
Sourceval convolve2d :
?groups:int ->
?stride:(int * int) ->
?padding_mode:[< `Full | `Same | `Valid > `Valid ] ->
?dilation:(int * int) ->
?fillvalue:'a ->
?bias:('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t
convolve2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w
computes 2D convolution (flips kernel before correlation).
Same parameters as correlate2d
but flips kernel horizontally and vertically. For `Same` with even kernels, pads more on top/left.
# let image = ones float32 [| 1; 1; 5; 5 |] in
let gaussian = create float32 [| 1; 1; 3; 3 |] [| 1.; 2.; 1.; 2.; 4.; 2.; 1.; 2.; 1. |] in
convolve2d image (mul_s gaussian (1. /. 16.)) |> shape
- : int array = [|1; 1; 3; 3|]
Sourceval avg_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?count_include_pad:bool ->
(float, 'a) t ->
(float, 'a) t
avg_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x
applies 1D average pooling.
- kernel_size: pooling window size
- stride: step between windows (default: kernel_size)
- dilation: spacing between kernel elements (default 1)
- padding_spec: same as convolution padding modes
- ceil_mode: use ceiling for output size calculation (default false)
- count_include_pad: include padding in average (default true)
Input shape: [batch; channels; width]
Output width: (width + 2*pad - dilation*(kernel-1) - 1)/stride + 1
# let x = create float32 [| 1; 1; 4 |] [| 1.; 2.; 3.; 4. |] in
avg_pool1d ~kernel_size:2 x
- : (float, float32_elt) t = [[[1.5, 3.5]]]
Sourceval avg_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?count_include_pad:bool ->
(float, 'a) t ->
(float, 'a) t
avg_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x
applies 2D average pooling.
- kernel_size: (height, width) of pooling window
- stride: (stride_h, stride_w) (default: kernel_size)
- dilation: (dilation_h, dilation_w) (default (1,1))
- count_include_pad: whether padding contributes to denominator
Input shape: [batch; channels; height; width]
# let x = create float32 [| 1; 1; 2; 2 |] [| 1.; 2.; 3.; 4. |] in
avg_pool2d ~kernel_size:(2, 2) x
- : (float, float32_elt) t = [[[[2.5]]]]
Sourceval max_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
max_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x
applies 1D max pooling.
- return_indices: if true, also returns indices of max values for unpooling
- Other parameters same as avg_pool1d
Returns (pooled_values, Some indices) if return_indices=true, otherwise (pooled_values, None). As in max_pool2d, each index is the position of the maximum within its pooling window; in the example below both maxima sit at window position 1.
# let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
let vals, idx = max_pool1d ~kernel_size:2 ~return_indices:true x in
vals, idx
- : (float, float32_elt) t * (int32, int32_elt) t option =
([[[3, 4]]], Some [[[1, 1]]])
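A plain-array sketch that reproduces the values and indices in this example, recording each maximum's position within its window. It is an illustration under that reading of the indices, with a hypothetical name, and it ignores padding, dilation, and ceil_mode.

```ocaml
(* 1-D max pooling over a non-empty float array. Returns the pooled
   values and, for each window, the offset of its first maximum. *)
let max_pool1d_sketch ~kernel_size ?(stride = kernel_size) x =
  let out = ((Array.length x - kernel_size) / stride) + 1 in
  let vals = Array.make out x.(0) and idx = Array.make out 0 in
  for o = 0 to out - 1 do
    let base = o * stride in
    vals.(o) <- x.(base);
    for j = 1 to kernel_size - 1 do
      (* Strict comparison keeps the first occurrence on ties. *)
      if x.(base + j) > vals.(o) then begin
        vals.(o) <- x.(base + j);
        idx.(o) <- j
      end
    done
  done;
  (vals, idx)
```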
Sourceval max_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
max_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x
applies 2D max pooling.
Parameters same as max_pool1d
but for 2D. Indices encode flattened position within each pooling window.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
let vals, _ = max_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
vals
- : (float, float32_elt) t = [[[[4, 8],
[12, 16]]]]
Sourceval min_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
min_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x
applies 1D min pooling.
- return_indices: if true, also returns indices of min values (currently returns None)
- Other parameters same as avg_pool1d
Returns (pooled_values, None). Index tracking is not yet implemented.
# let x = create float32 [| 1; 1; 4 |] [| 4.; 2.; 3.; 1. |] in
let vals, _ = min_pool1d ~kernel_size:2 x in
vals
- : (float, float32_elt) t = [[[2, 1]]]
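Min pooling is max pooling with the comparison flipped, so it also equals the negation of max pooling applied to the negated input. The sketch below shows this on plain `float array` values. The names `pool1d_sketch`, `min_pool1d_sketch`, and `max_pool1d_sketch` are hypothetical helpers, not library functions, and the code assumes no padding, dilation 1, floor mode, `kernel_size >= 1`, and a default stride equal to `kernel_size`.

```ocaml
(* Sketch only: hypothetical helpers, not part of the library.
   [combine] reduces a window pairwise (Float.min or Float.max). *)
let pool1d_sketch ~combine ~kernel_size ?(stride = kernel_size)
    (x : float array) =
  let n = Array.length x in
  let out_len =
    if n < kernel_size then 0 else ((n - kernel_size) / stride) + 1
  in
  Array.init out_len (fun o ->
      let acc = ref x.(o * stride) in
      for k = 1 to kernel_size - 1 do
        acc := combine !acc x.((o * stride) + k)
      done;
      !acc)

let min_pool1d_sketch ~kernel_size ?stride x =
  pool1d_sketch ~combine:Float.min ~kernel_size ?stride x

let max_pool1d_sketch ~kernel_size ?stride x =
  pool1d_sketch ~combine:Float.max ~kernel_size ?stride x

(* Negation identity: min_pool x = -(max_pool (-x)). *)
let min_via_max ~kernel_size x =
  let neg = Array.map (fun v -> -.v) in
  neg (max_pool1d_sketch ~kernel_size (neg x))
```

With the input `[| 4.; 2.; 3.; 1. |]` and `~kernel_size:2`, both `min_pool1d_sketch` and `min_via_max` give `[| 2.; 1. |]`, in line with the example above.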
val min_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
min_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x
applies 2D min pooling.
Parameters same as min_pool1d
but for 2D. Commonly used for morphological erosion operations in image processing.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
let vals, _ = min_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
vals
- : (float, float32_elt) t = [[[[1, 5],
[9, 13]]]]
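The link to morphological erosion can be illustrated with a plain-OCaml sketch: sliding a min window with stride 1 over a binary image shrinks each foreground region. `min_pool2d_sketch` is a hypothetical helper, not the library function, and it assumes a rectangular input no smaller than the kernel, with no padding and dilation 1.

```ocaml
(* Sketch only: hypothetical helper, not part of the library.
   Assumes a rectangular input no smaller than the kernel, no padding,
   dilation 1, floor mode. *)
let min_pool2d_sketch ~kernel_size:(kh, kw) ~stride:(sh, sw)
    (x : float array array) =
  let h = Array.length x and w = Array.length x.(0) in
  let out_h = ((h - kh) / sh) + 1 and out_w = ((w - kw) / sw) + 1 in
  Array.init out_h (fun i ->
      Array.init out_w (fun j ->
          (* Running minimum over the kh-by-kw window at (i, j). *)
          let best = ref infinity in
          for di = 0 to kh - 1 do
            for dj = 0 to kw - 1 do
              best := Float.min !best x.((i * sh) + di).((j * sw) + dj)
            done
          done;
          !best))

(* A 3x3 block of ones inside a 4x4 image. *)
let image =
  [| [| 1.; 1.; 1.; 0. |];
     [| 1.; 1.; 1.; 0. |];
     [| 1.; 1.; 1.; 0. |];
     [| 0.; 0.; 0.; 0. |] |]

(* Erosion with a 2x2 window and stride 1: the block shrinks to 2x2. *)
let eroded = min_pool2d_sketch ~kernel_size:(2, 2) ~stride:(1, 1) image
```

Here `eroded` is `[[1, 1, 0], [1, 1, 0], [0, 0, 0]]`: only windows lying entirely inside the block keep the value 1.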
val max_unpool1d :
(int, uint8_elt) t ->
('a, 'b) t ->
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?output_size_opt:int array ->
unit ->
(int, uint8_elt) t
max_unpool1d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt ()
reverses max pooling.
- indices: indices from max_pool1d with return_indices=true
- values: pooled values to place at indexed positions
- kernel_size, stride, dilation, padding_spec: must match original pool
- output_size_opt: exact output shape (inferred if not provided)
Places values at positions indicated by indices, fills rest with zeros. Output size computed from input unless explicitly specified.
# let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
let pooled, _ = max_pool1d ~kernel_size:2 x in
pooled
- : (float, float32_elt) t = [[[3, 4]]]
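The placement step described above can be sketched in plain OCaml. This is an illustration, not the library's implementation: `max_unpool1d_sketch` and its `?output_size` argument are hypothetical, the indices are assumed to be window-relative offsets (as in the `max_pool1d` example output `[[[1, 1]]]`), and padding and dilation are ignored. The inferred output length `(n - 1) * stride + kernel_size` is likewise an assumption.

```ocaml
(* Sketch only: hypothetical helper, not part of the library.
   Assumes window-relative indices, no padding, dilation 1. *)
let max_unpool1d_sketch ~kernel_size ?(stride = kernel_size) ?output_size
    (indices : int array) (values : float array) =
  let n = Array.length values in
  let out_len =
    match output_size with
    | Some len -> len
    | None -> if n = 0 then 0 else ((n - 1) * stride) + kernel_size
  in
  (* Every position not selected by an index stays zero. *)
  let out = Array.make out_len 0.0 in
  Array.iteri
    (fun o v -> out.((o * stride) + indices.(o)) <- v)
    values;
  out

(* max_unpool1d_sketch ~kernel_size:2 [| 1; 1 |] [| 3.; 4. |]
   evaluates to [| 0.; 3.; 0.; 4. |]. *)
```

Pooling `[1; 3; 2; 4]` and then unpooling recovers the maxima at their original positions and zeros elsewhere. With overlapping windows (stride smaller than the kernel), a later window in this sketch overwrites an earlier one that selected the same position.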
val max_unpool2d :
(int, uint8_elt) t ->
('a, 'b) t ->
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?output_size_opt:int array ->
unit ->
(int, uint8_elt) t
max_unpool2d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt ()
reverses 2D max pooling.
Same as max_unpool1d
but for 2D. Indices encode position within each pooling window. Useful for architectures like segmentation networks that need to "remember" where maxima came from.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.;
9.; 10.; 11.; 12.; 13.; 14.; 15.; 16. |] in
let pooled, _ = max_pool2d ~kernel_size:(2,2) x in
pooled
- : (float, float32_elt) t = [[[[6, 8],
[14, 16]]]]
one_hot ~num_classes indices
creates one-hot encoding.
Adds new last dimension of size num_classes
. Values must be in [0, num_classes)
. Out-of-range indices produce zero vectors.
# let indices = create int32 [| 3 |] [| 0l; 1l; 3l |] in
one_hot ~num_classes:4 indices
- : (int, uint8_elt) t = [[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]]
# let indices = create int32 [| 2; 2 |] [| 0l; 2l; 1l; 0l |] in
one_hot ~num_classes:3 indices |> shape
- : int array = [|2; 2; 3|]
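The encoding rule, including the zero vector for out-of-range indices, can be sketched on plain arrays. `one_hot_sketch` is a hypothetical helper operating on a flat `int array`, not the library function, which works on tensors of any rank.

```ocaml
(* Sketch only: hypothetical helper, not part of the library.
   Each index becomes a row of length num_classes with a single 1. *)
let one_hot_sketch ~num_classes (indices : int array) : int array array =
  Array.map
    (fun i ->
      let row = Array.make num_classes 0 in
      (* Out-of-range indices leave the row all zeros. *)
      if i >= 0 && i < num_classes then row.(i) <- 1;
      row)
    indices

(* one_hot_sketch ~num_classes:4 [| 0; 1; 3 |] evaluates to
   [| [|1;0;0;0|]; [|0;1;0;0|]; [|0;0;0;1|] |]. *)
```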
Iteration and Mapping
Functions to iterate over and transform arrays.
val map_item : ('a -> 'a) -> ('a, 'b) t -> ('a, 'b) t
map_item f t
applies f
to each element.
Operates on contiguous data directly. Type-preserving only: f must return the same element type it receives.
val iter_item : ('a -> unit) -> ('a, 'b) t -> unit
iter_item f t
applies f
to each element for side effects.
val fold_item : ('a -> 'b -> 'a) -> 'a -> ('b, 'c) t -> 'a
fold_item f init t
folds f
over elements.
val map : (('a, 'b) t -> ('a, 'b) t) -> ('a, 'b) t -> ('a, 'b) t
map f t
applies tensor function f
to each element, passed as a scalar tensor.
val iter : (('a, 'b) t -> unit) -> ('a, 'b) t -> unit
iter f t
applies tensor function f
to each element.
val fold : ('a -> ('b, 'c) t -> 'a) -> 'a -> ('b, 'c) t -> 'a
fold f init t
folds tensor function over elements.
Printing and Display
Functions to display arrays and convert to strings.
pp_data fmt t
pretty-prints tensor data.
format_to_string pp x
converts using pretty-printer.
print_with_formatter pp x
prints using formatter.
val data_to_string : ('a, 'b) t -> string
data_to_string t
converts tensor data to string.
val print_data : ('a, 'b) t -> unit
print_data t
prints tensor data to stdout.
pp_dtype fmt dt
pretty-prints dtype.
dtype_to_string dt
converts dtype to string.
val shape_to_string : int array -> string
shape_to_string shape
formats shape as "2x3x4".
pp_shape fmt shape
pretty-prints shape.
pp fmt t
pretty-prints tensor info and data.
val print : ('a, 'b) t -> unit
print t
prints tensor info and data to stdout.
val to_string : ('a, 'b) t -> string
to_string t
converts tensor info and data to string.