Module Nx

N-dimensional array operations for OCaml.

This module provides NumPy-style tensor operations. Tensors are immutable views over mutable buffers, supporting broadcasting, slicing, and efficient memory layout transformations.

Type System

The type ('a, 'b) t represents a tensor where 'a is the OCaml type of elements and 'b is the bigarray element type. For example, (float, float32_elt) t is a tensor of 32-bit floats.

Broadcasting

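Operations automatically broadcast compatible shapes: shapes are aligned from the rightmost dimension (missing leading dimensions count as 1), and each pair of dimensions must be equal or one of them must be 1. Shape [|3; 1; 5|] broadcasts with [|1; 4; 5|] to [|3; 4; 5|].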
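A minimal plain-OCaml sketch of this rule on shape arrays alone. broadcast_shape is a hypothetical helper for illustration, not a function of this module:

```ocaml
(* Compute the broadcast shape of two shapes, aligning from the right.
   Missing leading dimensions behave like 1. Raises Invalid_argument
   when a pair of dimensions is incompatible. *)
let broadcast_shape (a : int array) (b : int array) : int array =
  let na = Array.length a and nb = Array.length b in
  let n = max na nb in
  Array.init n (fun i ->
      let da = if i < n - na then 1 else a.(i - (n - na)) in
      let db = if i < n - nb then 1 else b.(i - (n - nb)) in
      if da = db then da
      else if da = 1 then db
      else if db = 1 then da
      else invalid_arg "broadcast_shape: incompatible dimensions")

let () =
  let s = broadcast_shape [| 3; 1; 5 |] [| 1; 4; 5 |] in
  (* prints 3; 4; 5 *)
  print_endline (String.concat "; " (Array.to_list (Array.map string_of_int s)))
```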

Memory Layout

Tensors can be C-contiguous or strided. Operations return views when possible (O(1)) and otherwise copy (O(n)). Use is_c_contiguous to check the layout and contiguous to ensure contiguity.
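The contiguity condition can be sketched in plain OCaml on element strides (the strides function documented below reports bytes, so divide by itemsize first). c_strides and is_row_major are hypothetical helpers, not Nx functions:

```ocaml
(* Element strides of a C-contiguous (row-major) layout: the last
   dimension has stride 1 and each earlier stride is the product of
   all later dimension sizes. *)
let c_strides (shape : int array) : int array =
  let n = Array.length shape in
  let strides = Array.make n 1 in
  for i = n - 2 downto 0 do
    strides.(i) <- strides.(i + 1) * shape.(i + 1)
  done;
  strides

(* A layout is row-major contiguous when its element strides match
   c_strides; size-1 dimensions are skipped because their stride is
   never used to reach a different element. *)
let is_row_major (shape : int array) (strides : int array) : bool =
  let expected = c_strides shape in
  let ok = ref true in
  Array.iteri
    (fun i d -> if d <> 1 && strides.(i) <> expected.(i) then ok := false)
    shape;
  !ok

let () =
  (* A 3x4 array has strides [|4; 1|]; its transpose (4x3) has [|1; 4|]. *)
  assert (is_row_major [| 3; 4 |] [| 4; 1 |]);
  assert (not (is_row_major [| 4; 3 |] [| 1; 4 |]))
```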

Type Definitions

type ('a, 'b) t = ('a, 'b) Nx_native.t

('a, 'b) t is a tensor with OCaml type 'a and bigarray type 'b.

type float16_elt = Bigarray.float16_elt
type float32_elt = Bigarray.float32_elt
type float64_elt = Bigarray.float64_elt
type int32_elt = Bigarray.int32_elt
type int64_elt = Bigarray.int64_elt
type int_elt = Bigarray.int_elt
type nativeint_elt = Bigarray.nativeint_elt
type complex32_elt = Bigarray.complex32_elt
type complex64_elt = Bigarray.complex64_elt
type ('a, 'b) dtype = ('a, 'b) Nx_core.Dtype.t =
  | Float16 : (float, float16_elt) dtype
  | Float32 : (float, float32_elt) dtype
  | Float64 : (float, float64_elt) dtype
  | Int8 : (int, int8_elt) dtype
  | UInt8 : (int, uint8_elt) dtype
  | Int16 : (int, int16_elt) dtype
  | UInt16 : (int, uint16_elt) dtype
  | Int32 : (int32, int32_elt) dtype
  | Int64 : (int64, int64_elt) dtype
  | Int : (int, int_elt) dtype
  | NativeInt : (nativeint, nativeint_elt) dtype
  | Complex32 : (Complex.t, complex32_elt) dtype
  | Complex64 : (Complex.t, complex64_elt) dtype

Data type specification. Links OCaml types to bigarray element types.

type float16_t = (float, float16_elt) t
type float32_t = (float, float32_elt) t
type float64_t = (float, float64_elt) t
type int8_t = (int, int8_elt) t
type uint8_t = (int, uint8_elt) t
type int16_t = (int, int16_elt) t
type uint16_t = (int, uint16_elt) t
type int32_t = (int32, int32_elt) t
type int64_t = (int64, int64_elt) t
type std_int_t = (int, int_elt) t
type std_nativeint_t = (nativeint, nativeint_elt) t
type complex32_t = (Complex.t, complex32_elt) t
type complex64_t = (Complex.t, complex64_elt) t
val float16 : (float, float16_elt) dtype
val float32 : (float, float32_elt) dtype
val float64 : (float, float64_elt) dtype
val int8 : (int, int8_elt) dtype
val uint8 : (int, uint8_elt) dtype
val int16 : (int, int16_elt) dtype
val uint16 : (int, uint16_elt) dtype
val int32 : (int32, int32_elt) dtype
val int64 : (int64, int64_elt) dtype
val int : (int, int_elt) dtype
val nativeint : (nativeint, nativeint_elt) dtype
type index =
  | I of int  (* Single index *)
  | L of int list  (* List of indices *)
  | R of int list  (* Range [start; stop; step] where stop is exclusive *)

Index specification for slicing.

Array Properties

Functions to inspect array dimensions, memory layout, and data access.

val data : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t

data t returns underlying bigarray buffer.

Buffer may contain data beyond tensor bounds for strided views. Direct access requires careful index computation using strides and offset.
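A sketch of that index computation, assuming offset is counted in elements (as documented for offset) and strides in bytes (as documented for strides). buffer_position is a hypothetical helper:

```ocaml
(* Flat buffer position of a multi-dimensional index:
   offset + sum_i index.(i) * (byte_strides.(i) / itemsize).
   Assumes offset is in elements and strides are in bytes. *)
let buffer_position ~offset ~byte_strides ~itemsize (index : int array) : int =
  let pos = ref offset in
  Array.iteri
    (fun i ix -> pos := !pos + (ix * (byte_strides.(i) / itemsize)))
    index;
  !pos

let () =
  (* 2x3 float32 tensor (itemsize 4): byte strides [|12; 4|]. *)
  let p = buffer_position ~offset:0 ~byte_strides:[| 12; 4 |] ~itemsize:4 [| 1; 2 |] in
  (* Its transposed view (3x2) swaps the strides and reaches the same slot. *)
  let q = buffer_position ~offset:0 ~byte_strides:[| 4; 12 |] ~itemsize:4 [| 2; 1 |] in
  assert (p = 5 && q = 5)
```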

val shape : ('a, 'b) t -> int array

shape t returns dimensions. Empty array for scalars.

val dtype : ('a, 'b) t -> ('a, 'b) dtype

dtype t returns data type.

val strides : ('a, 'b) t -> int array

strides t returns byte strides for each dimension.

val stride : int -> ('a, 'b) t -> int

stride i t returns byte stride for dimension i.

val dims : ('a, 'b) t -> int array

dims t is a synonym for shape.

val dim : int -> ('a, 'b) t -> int

dim i t returns size of dimension i.

val ndim : ('a, 'b) t -> int

ndim t returns number of dimensions.

val itemsize : ('a, 'b) t -> int

itemsize t returns bytes per element.

val size : ('a, 'b) t -> int

size t returns total number of elements.

val numel : ('a, 'b) t -> int

numel t is a synonym for size.

val nbytes : ('a, 'b) t -> int

nbytes t returns size t * itemsize t.

val offset : ('a, 'b) t -> int

offset t returns element offset in underlying buffer.

val is_c_contiguous : ('a, 'b) t -> bool

is_c_contiguous t returns true if elements are contiguous in C order.

val to_bigarray : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t

to_bigarray t converts to bigarray.

Always returns contiguous copy with same shape. Use for interop with libraries expecting bigarrays.

  # let t = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
  val t : (float, float32_elt) t = [[1, 2, 3],
                                    [4, 5, 6]]
  # Bigarray.Genarray.dims (to_bigarray t) = shape t
  - : bool = true
val to_array : ('a, 'b) t -> 'a array

to_array t converts to OCaml array.

Flattens tensor to 1-D array in row-major (C) order. Always copies.

  # let t = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |]
  val t : (int32, int32_elt) t = [[1, 2],
                                  [3, 4]]
  # to_array t
  - : int32 array = [|1l; 2l; 3l; 4l|]

Array Creation

Functions to create and initialize arrays.

val create : ('a, 'b) dtype -> int array -> 'a array -> ('a, 'b) t

create dtype shape data creates tensor from array data.

Length of data must equal product of shape.

  # create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
  - : (float, float32_elt) t = [[1, 2, 3],
                                [4, 5, 6]]
val init : ('a, 'b) dtype -> int array -> (int array -> 'a) -> ('a, 'b) t

init dtype shape f creates tensor where element at indices i has value f i.

Function f receives array of indices for each position. Useful for creating position-dependent values.

  # init int32 [| 2; 3 |] (fun i -> Int32.of_int (i.(0) + i.(1)))
  - : (int32, int32_elt) t = [[0, 1, 2],
                              [1, 2, 3]]

  # init float32 [| 3; 3 |] (fun i -> if i.(0) = i.(1) then 1. else 0.)
  - : (float, float32_elt) t = [[1, 0, 0],
                                [0, 1, 0],
                                [0, 0, 1]]
val empty : ('a, 'b) dtype -> int array -> ('a, 'b) t

empty dtype shape allocates uninitialized tensor.

val full : ('a, 'b) dtype -> int array -> 'a -> ('a, 'b) t

full dtype shape value creates tensor filled with value.

  # full float32 [| 2; 3 |] 3.14
  - : (float, float32_elt) t = [[3.14, 3.14, 3.14],
                                [3.14, 3.14, 3.14]]
val ones : ('a, 'b) dtype -> int array -> ('a, 'b) t

ones dtype shape creates tensor filled with ones.

val zeros : ('a, 'b) dtype -> int array -> ('a, 'b) t

zeros dtype shape creates tensor filled with zeros.

val scalar : ('a, 'b) dtype -> 'a -> ('a, 'b) t

scalar dtype value creates scalar tensor containing value.

val empty_like : ('a, 'b) t -> ('a, 'b) t

empty_like t creates uninitialized tensor with same shape and dtype as t.

val full_like : ('a, 'b) t -> 'a -> ('a, 'b) t

full_like t value creates tensor shaped like t filled with value.

val ones_like : ('a, 'b) t -> ('a, 'b) t

ones_like t creates tensor shaped like t filled with ones.

val zeros_like : ('a, 'b) t -> ('a, 'b) t

zeros_like t creates tensor shaped like t filled with zeros.

val scalar_like : ('a, 'b) t -> 'a -> ('a, 'b) t

scalar_like t value creates scalar with same dtype as t.

val eye : ?m:int -> ?k:int -> ('a, 'b) dtype -> int -> ('a, 'b) t

eye ?m ?k dtype n creates matrix with ones on k-th diagonal.

Default m = n (square), k = 0 (main diagonal). Positive k shifts diagonal above main, negative below.

  # eye int32 3
  - : (int32, int32_elt) t = [[1, 0, 0],
                              [0, 1, 0],
                              [0, 0, 1]]
  # eye ~k:1 int32 3
  - : (int32, int32_elt) t = [[0, 1, 0],
                              [0, 0, 1],
                              [0, 0, 0]]
  # eye ~m:2 ~k:(-1) int32 3
  - : (int32, int32_elt) t = [[0, 0, 0],
                              [1, 0, 0]]
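The diagonal rule can be checked with a plain-OCaml sketch on nested int arrays. eye_matrix is hypothetical and only mirrors the documented ?m and ?k semantics, where entry (i, j) is 1 exactly when j - i = k:

```ocaml
(* Build an m x n matrix with ones where column - row = k;
   m defaults to n and k to 0, as documented for eye. *)
let eye_matrix ?m ?(k = 0) n : int array array =
  let m = Option.value m ~default:n in
  Array.init m (fun i -> Array.init n (fun j -> if j - i = k then 1 else 0))
```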
val identity : ('a, 'b) dtype -> int -> ('a, 'b) t

identity dtype n creates n×n identity matrix.

Equivalent to eye dtype n. Square matrix with ones on main diagonal, zeros elsewhere.

  # identity int32 3
  - : (int32, int32_elt) t = [[1, 0, 0],
                              [0, 1, 0],
                              [0, 0, 1]]
val arange : ('a, 'b) dtype -> int -> int -> int -> ('a, 'b) t

arange dtype start stop step generates values in the half-open interval [start, stop), advancing by step.

Step must be non-zero. Result length is ceil((stop - start) / step), or 0 if that is negative.

  # arange int32 0 10 2
  - : (int32, int32_elt) t = [0, 2, 4, 6, 8]
  # arange int32 5 0 (-1)
  - : (int32, int32_elt) t = [5, 4, 3, 2, 1]
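A plain-OCaml sketch of the length rule for a half-open integer range, assuming NumPy-style semantics. arange_length and arange_list are hypothetical helpers:

```ocaml
(* Number of values start, start + step, ... that stay strictly before stop. *)
let arange_length start stop step =
  if step = 0 then invalid_arg "arange_length: step must be non-zero";
  let span = stop - start in
  (* Ceiling division: OCaml's / truncates toward zero, so the
     numerator is adjusted according to the sign of step. *)
  let n = if step > 0 then (span + step - 1) / step else (span + step + 1) / step in
  max 0 n

(* Materialize the values as an OCaml list. *)
let arange_list start stop step =
  List.init (arange_length start stop step) (fun i -> start + (i * step))
```

For example, arange_length 0 10 3 is 4 (values 0, 3, 6, 9), which truncating the quotient would undercount.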
val arange_f : (float, 'a) dtype -> float -> float -> float -> (float, 'a) t

arange_f dtype start stop step generates float values in [start, stop).

Like arange but for floating-point ranges. Handles fractional steps. Due to floating-point precision, final value may differ slightly from expected.

  # arange_f float32 0. 1. 0.2
  - : (float, float32_elt) t = [0, 0.2, 0.4, 0.6, 0.8]
  # arange_f float32 1. 0. (-0.25)
  - : (float, float32_elt) t = [1, 0.75, 0.5, 0.25]
val linspace : ('a, 'b) dtype -> ?endpoint:bool -> float -> float -> int -> ('a, 'b) t

linspace dtype ?endpoint start stop count generates count evenly spaced values from start to stop.

If endpoint is true (default), stop is included.

  # linspace float32 ~endpoint:true 0. 10. 5
  - : (float, float32_elt) t = [0, 2.5, 5, 7.5, 10]
  # linspace float32 ~endpoint:false 0. 10. 5
  - : (float, float32_elt) t = [0, 2, 4, 6, 8]
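A sketch of the spacing rule in plain OCaml, assuming the usual NumPy definition of the step. linspace_list is hypothetical; with endpoint:true the last value is computed, so in general it equals stop only up to rounding:

```ocaml
(* Step is (stop - start) / (count - 1) when stop is included and
   (stop - start) / count when it is excluded. *)
let linspace_list ?(endpoint = true) start stop count : float list =
  if count <= 0 then []
  else if count = 1 then [ start ]
  else begin
    let divisor = float_of_int (if endpoint then count - 1 else count) in
    let step = (stop -. start) /. divisor in
    List.init count (fun i -> start +. (float_of_int i *. step))
  end
```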
val logspace : (float, 'a) dtype -> ?endpoint:bool -> ?base:float -> float -> float -> int -> (float, 'a) t

logspace dtype ?endpoint ?base start_exp stop_exp count generates values evenly spaced on log scale.

Returns base ** x where x ranges from start_exp to stop_exp. Default base = 10.0.

  # logspace float32 0. 2. 3
  - : (float, float32_elt) t = [1, 10, 100]
  # logspace float32 ~base:2.0 0. 3. 4
  - : (float, float32_elt) t = [1, 2, 4, 8]
val geomspace : (float, 'a) dtype -> ?endpoint:bool -> float -> float -> int -> (float, 'a) t

geomspace dtype ?endpoint start stop count generates values evenly spaced on geometric (multiplicative) scale.

  # geomspace float32 1. 1000. 4
  - : (float, float32_elt) t = [1, 10, 100, 1000]
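Geometric spacing is linear spacing in log space, which is also how it relates to logspace: logspace with base b from s to e gives the same values as geomspace from b ** s to b ** e. A plain-OCaml sketch under that definition, with the endpoint included (geomspace_list is hypothetical):

```ocaml
(* Interpolate linearly between log start and log stop, then
   exponentiate, so consecutive values share a constant ratio. *)
let geomspace_list start stop count : float list =
  if start <= 0. || stop <= 0. then invalid_arg "geomspace_list: bounds must be positive";
  if count <= 0 then []
  else if count = 1 then [ start ]
  else begin
    let ls = log start and lt = log stop in
    let step = (lt -. ls) /. float_of_int (count - 1) in
    List.init count (fun i -> exp (ls +. (float_of_int i *. step)))
  end
```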
val meshgrid : ?indexing:[ `xy | `ij ] -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t

meshgrid ?indexing x y creates coordinate grids from 1D arrays.

Returns (X, Y) where X and Y are 2D arrays representing grid coordinates.

  • `xy (default): Cartesian indexing - X changes along columns, Y changes along rows
  • `ij: Matrix indexing - X changes along rows, Y changes along columns
  # let x = linspace float32 0. 2. 3 in
    let y = linspace float32 0. 1. 2 in
    meshgrid x y
  - : (float, float32_elt) t * (float, float32_elt) t =
  ([[0, 1, 2],
    [0, 1, 2]], [[0, 0, 0],
                 [1, 1, 1]])
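A plain-array sketch of the default `xy layout (meshgrid_xy is hypothetical): both grids have one row per element of y and one column per element of x.

```ocaml
(* X repeats x in every row; Y repeats y.(i) across row i. *)
let meshgrid_xy (x : 'a array) (y : 'a array) : 'a array array * 'a array array =
  let nx = Array.length x and ny = Array.length y in
  let gx = Array.init ny (fun _ -> Array.copy x) in
  let gy = Array.init ny (fun i -> Array.make nx y.(i)) in
  (gx, gy)
```

With `ij indexing the two grids are the transposes of these.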
val of_bigarray : ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t -> ('a, 'b) t

of_bigarray ba creates tensor from bigarray.

Zero-copy when bigarray is contiguous. Creates view sharing same memory. Modifications to either affect both.

  # let ba = Bigarray.Array2.create Bigarray.float32 Bigarray.c_layout 2 3 in
    Bigarray.Array2.fill ba 0.;  (* create leaves memory uninitialized *)
    let t = of_bigarray (Bigarray.genarray_of_array2 ba) in
    t
  - : (float, float32_elt) t = [[0, 0, 0],
                                [0, 0, 0]]

Random Number Generation

Functions to generate arrays with random values.

val rand : ('a, 'b) dtype -> ?seed:int -> int array -> ('a, 'b) t

rand dtype ?seed shape generates uniform random values in [0, 1).

Only supports float dtypes. Same seed produces same sequence.

val randn : ('a, 'b) dtype -> ?seed:int -> int array -> ('a, 'b) t

randn dtype ?seed shape generates standard normal random values.

Mean 0, variance 1. Uses Box-Muller transform for efficiency. Only supports float dtypes. Same seed produces same sequence.
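A plain-OCaml sketch of the Box-Muller transform, seeded through the standard library's Random.State. normal_samples is hypothetical, and Nx's own generator and seeding scheme are not assumed to match, so the values will differ from randn's:

```ocaml
(* Uniform sample in (0, 1]: retry on 0 so that log is always finite. *)
let rec positive_uniform (st : Random.State.t) : float =
  let u = Random.State.float st 1.0 in
  if u > 0.0 then u else positive_uniform st

(* Box-Muller: two independent uniforms give two independent
   standard normal samples (mean 0, variance 1). *)
let box_muller_pair (st : Random.State.t) : float * float =
  let u1 = positive_uniform st in
  let u2 = Random.State.float st 1.0 in
  let r = sqrt (-2.0 *. log u1) in
  let theta = 2.0 *. Float.pi *. u2 in
  (r *. cos theta, r *. sin theta)

(* n samples from an explicitly seeded generator: the same seed
   always reproduces the same array. *)
let normal_samples ~seed n : float array =
  let st = Random.State.make [| seed |] in
  let out = Array.make n 0.0 in
  let i = ref 0 in
  while !i < n do
    let z0, z1 = box_muller_pair st in
    out.(!i) <- z0;
    if !i + 1 < n then out.(!i + 1) <- z1;
    i := !i + 2
  done;
  out
```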

val randint : ('a, 'b) dtype -> ?seed:int -> ?high:int -> int array -> int -> ('a, 'b) t

randint dtype ?seed ?high shape low generates integers in [low, high).

Uniform distribution over range. Default high = 10. Note: high is exclusive (NumPy convention).

Shape Manipulation

Functions to reshape, transpose, and rearrange arrays.

val reshape : int array -> ('a, 'b) t -> ('a, 'b) t

reshape shape t returns view with new shape.

At most one dimension can be -1 (inferred from total elements). Product of dimensions must match total elements. Returns view when possible (O(1)), copies if tensor is not contiguous and cannot be viewed.

  # let t = create int32 [|2; 3|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
    reshape [|6|] t
  - : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
  # let t = create int32 [|6|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
    reshape [|3; -1|] t
  - : (int32, int32_elt) t = [[1, 2],
                              [3, 4],
                              [5, 6]]
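The -1 inference amounts to dividing the element count by the product of the known dimensions. A sketch on plain arrays (infer_shape is hypothetical):

```ocaml
(* Resolve a reshape target containing at most one -1 against the
   total number of elements. *)
let infer_shape (target : int array) (total : int) : int array =
  let unknown = ref None and known = ref 1 in
  Array.iteri
    (fun i d ->
      if d = -1 then begin
        if !unknown <> None then invalid_arg "infer_shape: more than one -1";
        unknown := Some i
      end
      else known := !known * d)
    target;
  let result = Array.copy target in
  (match !unknown with
   | Some i ->
       if !known = 0 || total mod !known <> 0 then
         invalid_arg "infer_shape: cannot infer the -1 dimension";
       result.(i) <- total / !known
   | None ->
       if !known <> total then invalid_arg "infer_shape: element count mismatch");
  result
```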
val broadcast_to : int array -> ('a, 'b) t -> ('a, 'b) t

broadcast_to shape t broadcasts tensor to target shape.

Shapes must be broadcast-compatible: dimensions align from right, each must be equal or source must be 1. Returns view (no copy) with zero strides for broadcast dimensions.

  # let t = create int32 [|1; 3|] [|1l; 2l; 3l|] in
    broadcast_to [|3; 3|] t
  - : (int32, int32_elt) t = [[1, 2, 3],
                              [1, 2, 3],
                              [1, 2, 3]]
  # let t = ones float32 [|3; 1|] in
    shape (broadcast_to [|2; 3; 4|] t)
  - : int array = [|2; 3; 4|]
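A sketch of how the strides of such a view can be derived, aligning dimensions from the right. broadcast_strides is a hypothetical helper working on element strides:

```ocaml
(* New leading dimensions and dimensions stretched from size 1 get
   stride 0, so every index along them reads the same element. *)
let broadcast_strides ~(shape : int array) ~(strides : int array) ~(target : int array) :
    int array =
  let n = Array.length target and k = Array.length shape in
  if k > n then invalid_arg "broadcast_strides: target has fewer dimensions";
  Array.init n (fun i ->
      let j = i - (n - k) in
      if j < 0 then 0 (* new leading dimension *)
      else if shape.(j) = target.(i) then strides.(j)
      else if shape.(j) = 1 then 0 (* stretched size-1 dimension *)
      else invalid_arg "broadcast_strides: incompatible shapes")
```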
val broadcasted : ?reverse:bool -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t

broadcasted ?reverse t1 t2 broadcasts tensors to common shape.

Returns views of both tensors broadcast to compatible shape. If reverse is true, returns (t2', t1') instead of (t1', t2'). Useful before element-wise operations.

  # let t1 = ones float32 [|3;1|] in
    let t2 = ones float32 [|1;5|] in
    let t1', t2' = broadcasted t1 t2 in
    shape t1', shape t2'
  - : int array * int array = ([|3; 5|], [|3; 5|])
val expand : int array -> ('a, 'b) t -> ('a, 'b) t

expand shape t broadcasts tensor where -1 keeps original dimension.

Like broadcast_to but -1 preserves existing dimension size. Adds dimensions on left if needed.

  # let t = ones float32 [|1; 4; 1|] in
    shape (expand [|3; -1; 5|] t)
  - : int array = [|3; 4; 5|]
  # let t = ones float32 [|5; 5|] in
    shape (expand [|-1; -1|] t)
  - : int array = [|5; 5|]
val flatten : ?start_dim:int -> ?end_dim:int -> ('a, 'b) t -> ('a, 'b) t

flatten ?start_dim ?end_dim t collapses dimensions into single dimension.

Default start_dim = 0, end_dim = -1 (last). Negative indices count from end. Dimensions start_dim through end_dim inclusive are flattened.

  # flatten (zeros float32 [| 2; 3; 4 |]) |> shape
  - : int array = [|24|]
  # flatten ~start_dim:1 ~end_dim:2 (zeros float32 [| 2; 3; 4; 5 |]) |> shape
  - : int array = [|2; 12; 5|]
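The resulting shape is a product over the collapsed range. A plain-OCaml sketch (flatten_shape is hypothetical):

```ocaml
(* Shape after merging dimensions start_dim..end_dim (inclusive) into
   one; negative dimension numbers count from the end. *)
let flatten_shape ?(start_dim = 0) ?(end_dim = (-1)) (shape : int array) : int array =
  let nd = Array.length shape in
  let norm d = if d < 0 then d + nd else d in
  let s = norm start_dim and e = norm end_dim in
  if s < 0 || e >= nd || s > e then invalid_arg "flatten_shape: bad dimension range"
  else begin
    let merged = ref 1 in
    for i = s to e do
      merged := !merged * shape.(i)
    done;
    Array.concat
      [ Array.sub shape 0 s; [| !merged |]; Array.sub shape (e + 1) (nd - e - 1) ]
  end
```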
val unflatten : int -> int array -> ('a, 'b) t -> ('a, 'b) t

unflatten dim sizes t expands dimension dim into multiple dimensions.

Product of sizes must equal size of dimension dim. At most one dimension can be -1 (inferred). Inverse of flatten.

  # unflatten 1 [| 3; 4 |] (zeros float32 [| 2; 12; 5 |]) |> shape
  - : int array = [|2; 3; 4; 5|]
  # unflatten 0 [| -1; 2 |] (ones float32 [| 6; 5 |]) |> shape
  - : int array = [|3; 2; 5|]
val ravel : ('a, 'b) t -> ('a, 'b) t

ravel t returns a contiguous 1-D tensor.

Equivalent to flatten t but always returns a contiguous result (a view when possible, otherwise a copy). Use when you need both flattening and contiguity.

  # let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    ravel x
  - : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
  # let t = transpose (ones float32 [| 3; 4 |]) in
    is_c_contiguous t, is_c_contiguous (ravel t)
  - : bool * bool = (false, true)
val squeeze : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t

squeeze ?axes t removes dimensions of size 1.

If axes specified, only removes those dimensions. Negative indices count from end. Returns view when possible.

  # squeeze (ones float32 [| 1; 3; 1; 4 |]) |> shape
  - : int array = [|3; 4|]
  # squeeze ~axes:[| 0; 2 |] (ones float32 [| 1; 3; 1; 4 |]) |> shape
  - : int array = [|3; 4|]
  # squeeze ~axes:[| -1 |] (ones float32 [| 3; 4; 1 |]) |> shape
  - : int array = [|3; 4|]
val unsqueeze : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t

unsqueeze ?axes t inserts dimensions of size 1 at specified positions.

Axes refer to positions in the result tensor, so each must lie in [0, ndim) where ndim is the rank of the result.

  • raises Invalid_argument if axes is not specified, out of bounds, or contains duplicates

  # unsqueeze ~axes:[| 0; 2 |] (create float32 [| 3 |] [| 1.; 2.; 3. |]) |> shape
  - : int array = [|1; 3; 1|]
  # unsqueeze ~axes:[| 1 |] (create float32 [| 2 |] [| 5.; 6. |]) |> shape
  - : int array = [|2; 1|]
val squeeze_axis : int -> ('a, 'b) t -> ('a, 'b) t

squeeze_axis axis t removes dimension axis if its size is 1.

val unsqueeze_axis : int -> ('a, 'b) t -> ('a, 'b) t

unsqueeze_axis axis t inserts dimension of size 1 at axis.

val expand_dims : int array -> ('a, 'b) t -> ('a, 'b) t

expand_dims axes t is a synonym for unsqueeze ~axes t.

val transpose : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t

transpose ?axes t permutes dimensions.

Default reverses all dimensions. axes must be permutation of 0..ndim-1. Returns view (no copy) with adjusted strides.

  # let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    transpose x
  - : (int32, int32_elt) t = [[1, 4],
                              [2, 5],
                              [3, 6]]
  # transpose ~axes:[| 2; 0; 1 |] (zeros float32 [| 2; 3; 4 |]) |> shape
  - : int array = [|4; 2; 3|]
  # let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    to_array (transpose (transpose x)) = to_array x
  - : bool = true
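Because the result is a view, transposition only permutes metadata. A sketch of that bookkeeping on plain arrays, assuming row-major element strides for the source (permute_view is hypothetical):

```ocaml
(* Dimension i of the result takes shape.(axes.(i)) and strides.(axes.(i))
   from the source; the underlying buffer is not touched. *)
let permute_view ~axes (shape : int array) (strides : int array) =
  let n = Array.length shape in
  if Array.length axes <> n then invalid_arg "permute_view: wrong number of axes";
  (* Check that axes is a permutation of 0..n-1. *)
  let seen = Array.make n false in
  Array.iter
    (fun a ->
      if a < 0 || a >= n || seen.(a) then invalid_arg "permute_view: not a permutation";
      seen.(a) <- true)
    axes;
  (Array.map (fun a -> shape.(a)) axes, Array.map (fun a -> strides.(a)) axes)
```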
val flip : ?axes:int array -> ('a, 'b) t -> ('a, 'b) t

flip ?axes t reverses order along specified dimensions.

Default flips all dimensions.

  # let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    flip x
  - : (int32, int32_elt) t = [[6, 5, 4],
                              [3, 2, 1]]
  # let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    flip ~axes:[| 1 |] x
  - : (int32, int32_elt) t = [[3, 2, 1],
                              [6, 5, 4]]
val moveaxis : int -> int -> ('a, 'b) t -> ('a, 'b) t

moveaxis src dst t moves dimension from src to dst.

val swapaxes : int -> int -> ('a, 'b) t -> ('a, 'b) t

swapaxes axis1 axis2 t exchanges two dimensions.

val roll : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t

roll ?axis shift t shifts elements along axis.

Elements shifted beyond last position wrap to beginning. If axis not specified, shifts flattened tensor. Negative shift rolls backward.

  # let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
    roll 2 x
  - : (int32, int32_elt) t = [4, 5, 1, 2, 3]
  # let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    roll ~axis:1 1 x
  - : (int32, int32_elt) t = [[3, 1, 2],
                              [6, 4, 5]]
  # let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    roll ~axis:0 (-1) x
  - : (int32, int32_elt) t = [[3, 4],
                              [1, 2]]
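The wrap-around is an index mapping on each 1-D slice. A sketch on plain OCaml arrays (roll_array is hypothetical):

```ocaml
(* The element at position i moves to (i + shift) mod n, wrapping
   around; negative shifts roll backward. *)
let roll_array shift (a : 'a array) : 'a array =
  let n = Array.length a in
  if n = 0 then [||]
  else begin
    (* Normalize the shift into [0, n), even for negative shifts. *)
    let s = ((shift mod n) + n) mod n in
    (* Output slot j reads input position (j - s) mod n. *)
    Array.init n (fun j -> a.((j - s + n) mod n))
  end
```

Rolling along an axis applies the same mapping to the slices along that axis: roll_array (-1) on the rows [| [|1; 2|]; [|3; 4|] |] yields [| [|3; 4|]; [|1; 2|] |], matching the axis-0 example above.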
val pad : (int * int) array -> 'a -> ('a, 'b) t -> ('a, 'b) t

pad padding value t pads tensor with value.

padding specifies (before, after) for each dimension. Length must match tensor dimensions. Negative padding not allowed.

  # let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    pad [| (1, 1); (2, 2) |] 0l x
  - : (int32, int32_elt) t =
  [[0, 0, 0, 0, 0, 0],
   [0, 0, 1, 2, 0, 0],
   [0, 0, 3, 4, 0, 0],
   [0, 0, 0, 0, 0, 0]]
val shrink : (int * int) array -> ('a, 'b) t -> ('a, 'b) t

shrink ranges t extracts slice from start to stop (exclusive) for each dimension.

  # let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
    shrink [| (1, 3); (0, 2) |] x
  - : (int32, int32_elt) t = [[4, 5],
                              [7, 8]]
val tile : int array -> ('a, 'b) t -> ('a, 'b) t

tile reps t constructs tensor by repeating t.

reps specifies repetitions per dimension. If longer than ndim, prepends dimensions. Zero repetitions create empty tensor.

  # let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
    tile [| 2; 3 |] x
  - : (int32, int32_elt) t = [[1, 2, 1, 2, 1, 2],
                              [1, 2, 1, 2, 1, 2]]
  # let x = create int32 [| 2 |] [| 1l; 2l |] in
    tile [| 2; 1; 3 |] x |> shape
  - : int array = [|2; 1; 6|]
val repeat : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t

repeat ?axis count t repeats elements count times.

If axis not specified, repeats flattened tensor.

  # let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    repeat 2 x
  - : (int32, int32_elt) t = [1, 1, 2, 2, 3, 3]
  # let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
    repeat ~axis:0 3 x
  - : (int32, int32_elt) t = [[1, 2],
                              [1, 2],
                              [1, 2]]

Array Combination and Splitting

Functions to join and split arrays.

val concatenate : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t

concatenate ?axis ts joins tensors along existing axis.

All tensors must have same shape except on concatenation axis. If axis not specified, flattens all tensors then concatenates. Returns contiguous result.

  # let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
    concatenate ~axis:0 [x1; x2]
  - : (int32, int32_elt) t = [[1, 2],
                              [3, 4],
                              [5, 6]]
  # let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
    concatenate [x1; x2]
  - : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
val stack : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t

stack ?axis ts joins tensors along new axis.

All tensors must have identical shape. Result rank is input rank + 1. Default axis=0. Negative axis counts from end of result shape.

  # let x1 = create int32 [| 2 |] [| 1l; 2l |] in
    let x2 = create int32 [| 2 |] [| 3l; 4l |] in
    stack [x1; x2]
  - : (int32, int32_elt) t = [[1, 2],
                              [3, 4]]
  # let x1 = create int32 [| 2 |] [| 1l; 2l |] in
    let x2 = create int32 [| 2 |] [| 3l; 4l |] in
    stack ~axis:1 [x1; x2]
  - : (int32, int32_elt) t = [[1, 3],
                              [2, 4]]
  # stack ~axis:(-1) [ones float32 [| 2; 3 |]; zeros float32 [| 2; 3 |]] |> shape
  - : int array = [|2; 3; 2|]
val vstack : ('a, 'b) t list -> ('a, 'b) t

vstack ts stacks tensors vertically (row-wise).

1-D tensors are treated as row vectors (shape [|1; n|]). Higher-D tensors concatenate along axis 0. All tensors must have same shape except possibly first dimension.

  # let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
    vstack [x1; x2]
  - : (int32, int32_elt) t = [[1, 2, 3],
                              [4, 5, 6]]
  # let x1 = create int32 [| 1; 2 |] [| 1l; 2l |] in
    let x2 = create int32 [| 2; 2 |] [| 3l; 4l; 5l; 6l |] in
    vstack [x1; x2]
  - : (int32, int32_elt) t = [[1, 2],
                              [3, 4],
                              [5, 6]]
val hstack : ('a, 'b) t list -> ('a, 'b) t

hstack ts stacks tensors horizontally (column-wise).

1-D tensors concatenate directly. Higher-D tensors concatenate along axis 1. For 1-D arrays of different lengths, use vstack to make 2-D first.

  # let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
    hstack [x1; x2]
  - : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
  # let x1 = create int32 [| 2; 1 |] [| 1l; 2l |] in
    let x2 = create int32 [| 2; 1 |] [| 3l; 4l |] in
    hstack [x1; x2]
  - : (int32, int32_elt) t = [[1, 3],
                              [2, 4]]
  # let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    let x2 = create int32 [| 2; 1 |] [| 5l; 6l |] in
    hstack [x1; x2]
  - : (int32, int32_elt) t = [[1, 2, 5],
                              [3, 4, 6]]
val dstack : ('a, 'b) t list -> ('a, 'b) t

dstack ts stacks tensors depth-wise (along third axis).

Tensors are reshaped to at least 3-D before concatenation:

  • 1-D: shape [|n|] becomes [|1; n; 1|]
  • 2-D: shape [|m; n|] becomes [|m; n; 1|]
  • 3-D and higher: unchanged
  # let x1 = create int32 [| 2 |] [| 1l; 2l |] in
    let x2 = create int32 [| 2 |] [| 3l; 4l |] in
    dstack [x1; x2]
  - : (int32, int32_elt) t = [[[1, 3],
                               [2, 4]]]
  # let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    let x2 = create int32 [| 2; 2 |] [| 5l; 6l; 7l; 8l |] in
    dstack [x1; x2]
  - : (int32, int32_elt) t = [[[1, 5],
                               [2, 6]],
                              [[3, 7],
                               [4, 8]]]
val broadcast_arrays : ('a, 'b) t list -> ('a, 'b) t list

broadcast_arrays ts broadcasts all tensors to common shape.

Finds the common broadcast shape and returns list of views with that shape. Broadcasting rules: dimensions align right, each must be 1 or equal.

  # let x1 = ones float32 [| 3; 1 |] in
    let x2 = ones float32 [| 1; 5 |] in
    broadcast_arrays [x1; x2] |> List.map shape
  - : int array list = [[|3; 5|]; [|3; 5|]]
  # let x1 = scalar float32 5. in
    let x2 = ones float32 [| 2; 3; 4 |] in
    broadcast_arrays [x1; x2] |> List.map shape
  - : int array list = [[|2; 3; 4|]; [|2; 3; 4|]]
val array_split : axis:int -> [< `Count of int | `Indices of int list ] -> ('a, 'b) t -> ('a, 'b) t list

array_split ~axis sections t splits tensor into multiple parts.

`Count n divides into n parts as evenly as possible. Extra elements go to first parts. `Indices [i1;i2;...] splits at indices creating start:i1, i1:i2, i2:end.

  # let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
    array_split ~axis:0 (`Count 3) x
  - : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5]]
  # let x = create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    array_split ~axis:0 (`Indices [ 2; 4 ]) x
  - : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5, 6]]
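The `Count rule can be sketched as integer arithmetic on section lengths. split_sizes and split_bounds are hypothetical helpers:

```ocaml
(* Split n elements into k parts as evenly as possible: the first
   (n mod k) parts get one extra element. *)
let split_sizes n k =
  if k <= 0 then invalid_arg "split_sizes: need at least one section";
  List.init k (fun i -> (n / k) + (if i < n mod k then 1 else 0))

(* Turn the section lengths into [start, stop) bounds along the axis. *)
let split_bounds n k =
  let _, bounds =
    List.fold_left
      (fun (start, acc) len -> (start + len, (start, start + len) :: acc))
      (0, []) (split_sizes n k)
  in
  List.rev bounds
```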
val split : axis:int -> int -> ('a, 'b) t -> ('a, 'b) t list

split ~axis sections t splits into equal parts.

  # let x = create int32 [| 4; 2 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
    split ~axis:0 2 x
  - : (int32, int32_elt) t list = [[[1, 2],
                                    [3, 4]]; [[5, 6],
                                              [7, 8]]]

Type Conversion and Copying

Functions to convert between types and create copies.

val cast : ('c, 'd) dtype -> ('a, 'b) t -> ('c, 'd) t

cast dtype t converts elements to new dtype.

Returns copy with same values in new type.

  # let x = create float32 [| 3 |] [| 1.5; 2.7; 3.1 |] in
    cast int32 x
  - : (int32, int32_elt) t = [1, 2, 3]
val astype : ('a, 'b) dtype -> ('c, 'd) t -> ('a, 'b) t

astype dtype t is a synonym for cast.

val contiguous : ('a, 'b) t -> ('a, 'b) t

contiguous t returns C-contiguous tensor.

Returns t unchanged if already contiguous (O(1)), otherwise creates contiguous copy (O(n)). Use before operations requiring direct memory access.

  # let t = transpose (ones float32 [| 3; 4 |]) in
    is_c_contiguous (contiguous t)
  - : bool = true
val copy : ('a, 'b) t -> ('a, 'b) t

copy t returns deep copy.

Always allocates new memory and copies data. Result is contiguous.

  # let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
    let y = copy x in
    set_item [ 0 ] 999. y;
    x, y
  - : (float, float32_elt) t * (float, float32_elt) t =
  ([1, 2, 3], [999, 2, 3])
val blit : ('a, 'b) t -> ('a, 'b) t -> unit

blit src dst copies src into dst.

Shapes must match exactly. Handles broadcasting internally. Modifies dst in-place.

  let dst = zeros float32 [| 3; 3 |] in
  blit (ones float32 [| 3; 3 |]) dst
  (* dst now contains all 1s *)
val fill : 'a -> ('a, 'b) t -> ('a, 'b) t

fill value t sets all elements to value.

Modifies t in-place and returns it for chaining.

  # let x = zeros float32 [| 2; 3 |] in
    let y = fill 5. x in
    y == x
  - : bool = true

Element Access and Slicing

Functions to access and modify array elements.

val slice : index list -> ('a, 'b) t -> ('a, 'b) t

slice indices t extracts subtensor.

  • I n: select index n (reduces dimension)
  • L [i;j;k]: fancy indexing - select indices i, j, k
  • R [start;stop;step]: range [start, stop) with step

Stop is exclusive. Negative indices count from end. Missing indices select all. Returns view when possible.

  # let x = create int32 [| 2; 4 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
    slice [ I 1 ] x
  - : (int32, int32_elt) t = [5, 6, 7, 8]
  # let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
    slice [ R [ 1; 3 ] ] x
  - : (int32, int32_elt) t = [1, 2]
val set_slice : index list -> ('a, 'b) t -> ('a, 'b) t -> unit

set_slice indices t value assigns value to slice.

val slice_ranges : ?steps:int list -> int list -> int list -> ('a, 'b) t -> ('a, 'b) t

slice_ranges ?steps starts stops t extracts ranges.

Equivalent to slice [R[s0;e0;st0]; R[s1;e1;st1]; ...] t. Lists must have same length ≤ ndim. Default step is 1. Missing dimensions select all.

  # let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
    slice_ranges [ 0; 1 ] [ 2; 3 ] x
  - : (int32, int32_elt) t = [[2, 3],
                              [5, 6]]
  # slice_ranges ~steps:[ 2; 1 ] [ 0; 0 ] [ 4; 2 ] (eye int32 4)
  - : (int32, int32_elt) t = [[1, 0],
                              [0, 0]]
val set_slice_ranges : ?steps:int list -> int list -> int list -> ('a, 'b) t -> ('a, 'b) t -> unit

set_slice_ranges ?steps starts stops t value assigns to ranges.

Like slice_ranges but assigns value to selected region. Value is broadcast to target shape if needed.

  # let x = zeros float32 [| 3; 3 |] in
    set_slice_ranges [ 1; 2 ] [ 2; 3 ] x (ones float32 [| 1; 1 |]);
    get_item [ 1; 2 ] x
  - : float = 1.
val get : int list -> ('a, 'b) t -> ('a, 'b) t

get indices t returns subtensor at indices.

Indexes from outermost dimension. Returns scalar tensor if all dimensions indexed, otherwise returns view of remaining dimensions.

  # let x = create int32 [| 2; 2; 2 |] [| 0l; 1l; 2l; 3l; 4l; 5l; 6l; 7l |] in
    get [ 1; 1; 1 ] x
  - : (int32, int32_elt) t = 7
  # let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
    get [ 1 ] x
  - : (int32, int32_elt) t = [4, 5, 6]
val set : int list -> ('a, 'b) t -> ('a, 'b) t -> unit

set indices t value assigns value at indices.

val get_item : int list -> ('a, 'b) t -> 'a

get_item indices t returns scalar value at indices.

Must provide indices for all dimensions.

val set_item : int list -> 'a -> ('a, 'b) t -> unit

set_item indices value t sets scalar value at indices.

Must provide indices for all dimensions. Modifies tensor in-place.

Basic Arithmetic Operations

Element-wise arithmetic operations and their variants.

val add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

add t1 t2 computes element-wise sum with broadcasting.

val add_s : ('a, 'b) t -> 'a -> ('a, 'b) t

add_s t scalar adds scalar to each element.

val iadd : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

iadd target value adds value to target in-place.

Returns modified target.

Sourceval radd_s : 'a -> ('a, 'b) t -> ('a, 'b) t

radd_s scalar t is add_s t scalar.

Sourceval iadd_s : ('a, 'b) t -> 'a -> ('a, 'b) t

iadd_s t scalar adds scalar to t in-place.

Sourceval sub : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

sub t1 t2 computes element-wise difference with broadcasting.

Sourceval sub_s : ('a, 'b) t -> 'a -> ('a, 'b) t

sub_s t scalar subtracts scalar from each element.

Sourceval rsub_s : 'a -> ('a, 'b) t -> ('a, 'b) t

rsub_s scalar t computes scalar - t.

Sourceval isub : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

isub target value subtracts value from target in-place.

Sourceval isub_s : ('a, 'b) t -> 'a -> ('a, 'b) t

isub_s t scalar subtracts scalar from t in-place.

Sourceval mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

mul t1 t2 computes element-wise product with broadcasting.

Sourceval mul_s : ('a, 'b) t -> 'a -> ('a, 'b) t

mul_s t scalar multiplies each element by scalar.

Sourceval rmul_s : 'a -> ('a, 'b) t -> ('a, 'b) t

rmul_s scalar t is mul_s t scalar.

Sourceval imul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

imul target value multiplies target by value in-place.

Sourceval imul_s : ('a, 'b) t -> 'a -> ('a, 'b) t

imul_s t scalar multiplies t by scalar in-place.

Sourceval div : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

div t1 t2 computes element-wise division.

True division for floats (result is float). Integer division for integers (truncates toward zero). Complex division follows standard rules.

  # let x = create float32 [| 3 |] [| 7.; 8.; 9. |] in
    let y = create float32 [| 3 |] [| 2.; 2.; 2. |] in
    div x y
  - : (float, float32_elt) t = [3.5, 4, 4.5]
  # let x = create int32 [| 3 |] [| 7l; 8l; 9l |] in
    let y = create int32 [| 3 |] [| 2l; 2l; 2l |] in
    div x y
  - : (int32, int32_elt) t = [3, 4, 4]
  # let x = create int32 [| 2 |] [| -7l; 8l |] in
    let y = create int32 [| 2 |] [| 2l; 2l |] in
    div x y
  - : (int32, int32_elt) t = [-3, 4]
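OCaml's own Int32.div uses the same truncate-toward-zero rule documented above, which is worth keeping in mind when porting NumPy code, where the // operator floors instead. This scalar sketch (plain stdlib, not Nx; floor_div is a hypothetical helper) contrasts the two conventions.

```ocaml
(* OCaml's built-in integer division truncates toward zero, the convention
   documented for div on integer tensors. *)
let trunc_div (a : int32) (b : int32) : int32 = Int32.div a b

(* Floor division for comparison: rounds toward negative infinity. *)
let floor_div (a : int32) (b : int32) : int32 =
  let q = Int32.div a b in
  (* Step down when there is a remainder and the operands have opposite signs *)
  if Int32.rem a b <> 0l && (Int32.compare a 0l < 0) <> (Int32.compare b 0l < 0)
  then Int32.pred q
  else q

let () =
  Printf.printf "trunc: %ld, floor: %ld\n" (trunc_div (-7l) 2l) (floor_div (-7l) 2l)
```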
Sourceval div_s : ('a, 'b) t -> 'a -> ('a, 'b) t

div_s t scalar divides each element by scalar.

Sourceval rdiv_s : 'a -> ('a, 'b) t -> ('a, 'b) t

rdiv_s scalar t computes scalar / t.

Sourceval idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

idiv target value divides target by value in-place.

Sourceval idiv_s : ('a, 'b) t -> 'a -> ('a, 'b) t

idiv_s t scalar divides t by scalar in-place.

Sourceval pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

pow base exponent computes element-wise power.

Sourceval pow_s : ('a, 'b) t -> 'a -> ('a, 'b) t

pow_s t scalar raises each element to scalar power.

Sourceval rpow_s : 'a -> ('a, 'b) t -> ('a, 'b) t

rpow_s scalar t computes scalar ** t.

Sourceval ipow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

ipow target exponent raises target to exponent in-place.

Sourceval ipow_s : ('a, 'b) t -> 'a -> ('a, 'b) t

ipow_s t scalar raises t to scalar power in-place.

Sourceval mod_ : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

mod_ t1 t2 computes element-wise modulo.

Sourceval mod_s : ('a, 'b) t -> 'a -> ('a, 'b) t

mod_s t scalar computes modulo scalar for each element.

Sourceval rmod_s : 'a -> ('a, 'b) t -> ('a, 'b) t

rmod_s scalar t computes scalar mod t.

Sourceval imod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

imod target divisor computes modulo in-place.

Sourceval imod_s : ('a, 'b) t -> 'a -> ('a, 'b) t

imod_s t scalar computes modulo scalar in-place.

Sourceval neg : ('a, 'b) t -> ('a, 'b) t

neg t negates all elements.

Mathematical Functions

Unary mathematical operations and special functions.

Sourceval abs : ('a, 'b) t -> ('a, 'b) t

abs t computes absolute value.

Sourceval sign : ('a, 'b) t -> ('a, 'b) t

sign t returns -1, 0, or 1 based on sign.

For unsigned types, returns 1 for all non-zero values, 0 for zero.

  # let x = create float32 [| 3 |] [| -2.; 0.; 3.5 |] in
    sign x
  - : (float, float32_elt) t = [-1, 0, 1]
Sourceval square : ('a, 'b) t -> ('a, 'b) t

square t computes element-wise square.

Sourceval sqrt : ('a, 'b) t -> ('a, 'b) t

sqrt t computes element-wise square root.

Sourceval rsqrt : ('a, 'b) t -> ('a, 'b) t

rsqrt t computes reciprocal square root.

Sourceval recip : ('a, 'b) t -> ('a, 'b) t

recip t computes element-wise reciprocal.

Sourceval log : (float, 'a) t -> (float, 'a) t

log t computes natural logarithm.

Sourceval log2 : ('a, 'b) t -> ('a, 'b) t

log2 t computes base-2 logarithm.

Sourceval exp : (float, 'a) t -> (float, 'a) t

exp t computes exponential.

Sourceval exp2 : ('a, 'b) t -> ('a, 'b) t

exp2 t computes 2^x.

Sourceval sin : ('a, 'b) t -> ('a, 'b) t

sin t computes sine.

Sourceval cos : (float, 'a) t -> (float, 'a) t

cos t computes cosine.

Sourceval tan : (float, 'a) t -> (float, 'a) t

tan t computes tangent.

Sourceval asin : (float, 'a) t -> (float, 'a) t

asin t computes arcsine.

Sourceval acos : (float, 'a) t -> (float, 'a) t

acos t computes arccosine.

Sourceval atan : (float, 'a) t -> (float, 'a) t

atan t computes arctangent.

Sourceval atan2 : (float, 'a) t -> (float, 'a) t -> (float, 'a) t

atan2 y x computes arctangent of y/x using signs to determine quadrant.

Returns the angle in radians in the range [-π, π]. Also defined at x = 0: for example, atan2 1 0 is π/2 and atan2 0 0 is 0.

  # let y = scalar float32 1. in
    let x = scalar float32 1. in
    atan2 y x |> get_item [] |> Float.round
  - : float = 1.
  # let y = scalar float32 1. in
    let x = scalar float32 0. in
    atan2 y x |> get_item [] |> Float.round
  - : float = 2.
  # let y = scalar float32 0. in
    let x = scalar float32 0. in
    atan2 y x |> get_item []
  - : float = 0.
Sourceval sinh : (float, 'a) t -> (float, 'a) t

sinh t computes hyperbolic sine.

Sourceval cosh : (float, 'a) t -> (float, 'a) t

cosh t computes hyperbolic cosine.

Sourceval tanh : (float, 'a) t -> (float, 'a) t

tanh t computes hyperbolic tangent.

Sourceval asinh : (float, 'a) t -> (float, 'a) t

asinh t computes inverse hyperbolic sine.

Sourceval acosh : (float, 'a) t -> (float, 'a) t

acosh t computes inverse hyperbolic cosine.

Sourceval atanh : (float, 'a) t -> (float, 'a) t

atanh t computes inverse hyperbolic tangent.

Sourceval hypot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

hypot x y computes sqrt(x² + y²) avoiding overflow.

Uses a numerically stable formulation: max * sqrt(1 + (min/max)²), where max and min are the larger and smaller of |x| and |y|.

  # let x = scalar float32 3. in
    let y = scalar float32 4. in
    hypot x y |> get_item []
  - : float = 5.
  # let x = scalar float64 1e200 in
    let y = scalar float64 1e200 in
    hypot x y |> get_item [] < Float.infinity
  - : bool = true
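The rescaling formula quoted above can be checked with plain OCaml floats. This is a scalar sketch of that formula only, not Nx's implementation (the OCaml stdlib also ships Float.hypot); hypot_scaled is a hypothetical name.

```ocaml
(* Scaled hypotenuse: max * sqrt(1 + (min/max)^2), which never squares
   the large magnitude directly. *)
let hypot_scaled (x : float) (y : float) : float =
  let ax = Float.abs x and ay = Float.abs y in
  let hi = Float.max ax ay and lo = Float.min ax ay in
  if hi = 0. then 0.
  else
    let r = lo /. hi in
    hi *. Float.sqrt (1. +. (r *. r))

let () =
  (* The naive formula overflows to infinity for 1e200; the scaled one does not *)
  let naive = Float.sqrt ((1e200 *. 1e200) +. (1e200 *. 1e200)) in
  Printf.printf "naive=%g scaled=%g\n" naive (hypot_scaled 1e200 1e200)
```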
Sourceval trunc : ('a, 'b) t -> ('a, 'b) t

trunc t rounds toward zero.

Removes fractional part. Positive values round down, negative round up.

  # let x = create float32 [| 3 |] [| 2.7; -2.7; 2.0 |] in
    trunc x
  - : (float, float32_elt) t = [2, -2, 2]
Sourceval ceil : (float, 'a) t -> (float, 'a) t

ceil t rounds up to nearest integer.

Smallest integer not less than input.

  # let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
    ceil x
  - : (float, float32_elt) t = [3, 3, -2, -2]
Sourceval floor : (float, 'a) t -> (float, 'a) t

floor t rounds down to nearest integer.

Largest integer not greater than input.

  # let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
    floor x
  - : (float, float32_elt) t = [2, 2, -3, -3]
Sourceval round : (float, 'a) t -> (float, 'a) t

round t rounds to nearest integer (half away from zero).

Ties round away from zero (not banker's rounding).

  # let x = create float32 [| 4 |] [| 2.5; 3.5; -2.5; -3.5 |] in
    round x
  - : (float, float32_elt) t = [3, 4, -3, -4]
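OCaml's stdlib Float functions follow the same conventions as these four operations (Float.round also sends ties away from zero), so a scalar comparison in plain OCaml, without Nx, shows how they differ on ties and negative inputs:

```ocaml
(* Compare the four rounding modes on the same inputs. *)
let inputs = [| 2.5; -2.5; 2.7; -2.7; 2.1; -2.1 |]

(* Print one row: the mode name followed by f applied to each input *)
let show name f =
  Printf.printf "%-6s" name;
  Array.iter (fun x -> Printf.printf " %5.1f" (f x)) inputs;
  print_newline ()

let () =
  show "input" Fun.id;
  show "round" Float.round;
  show "trunc" Float.trunc;
  show "ceil" Float.ceil;
  show "floor" Float.floor
```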
Sourceval lerp : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

lerp start end_ weight computes linear interpolation.

Returns start + weight * (end_ - start). Weight is typically in [0, 1].

  # let start = scalar float32 0. in
    let end_ = scalar float32 10. in
    let weight = scalar float32 0.3 in
    lerp start end_ weight |> get_item []
  - : float = 3.
  # let start = create float32 [| 2 |] [| 1.; 2. |] in
    let end_ = create float32 [| 2 |] [| 5.; 8. |] in
    let weight = create float32 [| 2 |] [| 0.25; 0.5 |] in
    lerp start end_ weight
  - : (float, float32_elt) t = [2, 5]
Sourceval lerp_scalar_weight : ('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

lerp_scalar_weight start end_ weight interpolates with scalar weight.

Comparison and Logical Operations

Element-wise comparisons and logical operations.

Sourceval cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

cmplt t1 t2 returns 1 where t1 < t2, 0 elsewhere.

Sourceval less : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

less t1 t2 is a synonym for cmplt.

Sourceval cmpne : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

cmpne t1 t2 returns 1 where t1 ≠ t2, 0 elsewhere.

Sourceval not_equal : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

not_equal t1 t2 is a synonym for cmpne.

Sourceval cmpeq : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

cmpeq t1 t2 returns 1 where t1 = t2, 0 elsewhere.

Sourceval equal : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

equal t1 t2 is a synonym for cmpeq.

Sourceval cmpgt : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

cmpgt t1 t2 returns 1 where t1 > t2, 0 elsewhere.

Sourceval greater : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

greater t1 t2 is a synonym for cmpgt.

Sourceval cmple : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

cmple t1 t2 returns 1 where t1 ≤ t2, 0 elsewhere.

Sourceval less_equal : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

less_equal t1 t2 is a synonym for cmple.

Sourceval cmpge : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

cmpge t1 t2 returns 1 where t1 ≥ t2, 0 elsewhere.

Sourceval greater_equal : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

greater_equal t1 t2 is a synonym for cmpge.

Sourceval array_equal : ('a, 'b) t -> ('a, 'b) t -> (int, uint8_elt) t

array_equal t1 t2 returns scalar 1 if all elements equal, 0 otherwise.

Broadcasts inputs before comparison. Returns 0 if shapes incompatible.

  # let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    let y = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    array_equal x y |> get_item []
  - : int = 1
  # let x = create int32 [| 2 |] [| 1l; 2l |] in
    let y = create int32 [| 2 |] [| 1l; 3l |] in
    array_equal x y |> get_item []
  - : int = 0
Sourceval maximum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

maximum t1 t2 returns element-wise maximum.

Sourceval maximum_s : ('a, 'b) t -> 'a -> ('a, 'b) t

maximum_s t scalar returns maximum of each element and scalar.

Sourceval rmaximum_s : 'a -> ('a, 'b) t -> ('a, 'b) t

rmaximum_s scalar t is maximum_s t scalar.

Sourceval imaximum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

imaximum target value computes maximum in-place.

Sourceval imaximum_s : ('a, 'b) t -> 'a -> ('a, 'b) t

imaximum_s t scalar computes maximum with scalar in-place.

Sourceval minimum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

minimum t1 t2 returns element-wise minimum.

Sourceval minimum_s : ('a, 'b) t -> 'a -> ('a, 'b) t

minimum_s t scalar returns minimum of each element and scalar.

Sourceval rminimum_s : 'a -> ('a, 'b) t -> ('a, 'b) t

rminimum_s scalar t is minimum_s t scalar.

Sourceval iminimum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

iminimum target value computes minimum in-place.

Sourceval iminimum_s : ('a, 'b) t -> 'a -> ('a, 'b) t

iminimum_s t scalar computes minimum with scalar in-place.

Sourceval logical_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

logical_and t1 t2 computes element-wise AND.

Non-zero values are true.

Sourceval logical_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

logical_or t1 t2 computes element-wise OR.

Sourceval logical_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

logical_xor t1 t2 computes element-wise XOR.

Sourceval logical_not : ('a, 'b) t -> ('a, 'b) t

logical_not t computes element-wise NOT.

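Computes 1 - x, so it acts as a logical NOT only on inputs containing 0 and 1: zero becomes 1 and one becomes 0, but other non-zero values are not mapped to 0 (5 becomes -4, as in the example).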

  # let x = create int32 [| 3 |] [| 0l; 1l; 5l |] in
    logical_not x
  - : (int32, int32_elt) t = [1, 0, -4]
Sourceval isinf : (float, 'a) t -> (int, uint8_elt) t

isinf t returns 1 where infinite, 0 elsewhere.

Detects both positive and negative infinity. Non-float types return all 0s.

  # let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.neg_infinity; Float.nan |] in
    isinf x
  - : (int, uint8_elt) t = [0, 1, 1, 0]
Sourceval isnan : ('a, 'b) t -> (int, uint8_elt) t

isnan t returns 1 where NaN, 0 elsewhere.

NaN is the only value that doesn't equal itself. Non-float types return all 0s.

  # let x = create float32 [| 3 |] [| 1.; Float.nan; Float.infinity |] in
    isnan x
  - : (int, uint8_elt) t = [0, 1, 0]
Sourceval isfinite : (float, 'a) t -> (int, uint8_elt) t

isfinite t returns 1 where finite, 0 elsewhere.

Finite means not inf, -inf, or NaN. Non-float types return all 1s.

  # let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.nan; -0. |] in
    isfinite x
  - : (int, uint8_elt) t = [1, 0, 0, 1]
Sourceval where : (int, uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

where cond if_true if_false selects elements based on condition.

Returns if_true where cond is non-zero, if_false elsewhere. All three inputs broadcast to common shape.

  # let cond = create uint8 [| 3 |] [| 1; 0; 1 |] in
    let if_true = create int32 [| 3 |] [| 2l; 3l; 4l |] in
    let if_false = create int32 [| 3 |] [| 5l; 6l; 7l |] in
    where cond if_true if_false
  - : (int32, int32_elt) t = [2, 6, 4]
  # let x = create float32 [| 4 |] [| -1.; 2.; -3.; 4. |] in
    where (cmpgt x (scalar float32 0.)) x (scalar float32 0.)
  - : (float, float32_elt) t = [0, 2, 0, 4]
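The selection rule can be sketched over plain OCaml arrays of equal length. where_arrays is a hypothetical helper, not an Nx function, and it deliberately leaves out the broadcasting that where performs on all three inputs.

```ocaml
(* Element-wise selection on equal-length arrays: non-zero mask entries pick
   from if_true, zero entries from if_false. Broadcasting is omitted. *)
let where_arrays (cond : int array) (if_true : 'a array) (if_false : 'a array) : 'a array =
  let n = Array.length cond in
  if Array.length if_true <> n || Array.length if_false <> n then
    invalid_arg "where_arrays: length mismatch";
  Array.init n (fun i -> if cond.(i) <> 0 then if_true.(i) else if_false.(i))

let () =
  (* ReLU-style masking, mirroring the second example above *)
  let x = [| -1.; 2.; -3.; 4. |] in
  let mask = Array.map (fun v -> if v > 0. then 1 else 0) x in
  let zeros = Array.make (Array.length x) 0. in
  where_arrays mask x zeros |> Array.iter (Printf.printf "%g ");
  print_newline ()
```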
Sourceval clamp : ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t

clamp ?min ?max t limits values to the range [min, max].

Elements below min become min, above max become max.

Sourceval clip : ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t

clip ?min ?max t is a synonym for clamp.

Bitwise Operations

Bitwise operations on integer arrays.

Sourceval bitwise_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

bitwise_xor t1 t2 computes element-wise XOR.

Sourceval bitwise_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

bitwise_or t1 t2 computes element-wise OR.

Sourceval bitwise_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

bitwise_and t1 t2 computes element-wise AND.

Sourceval bitwise_not : ('a, 'b) t -> ('a, 'b) t

bitwise_not t computes element-wise NOT.

Sourceval invert : ('a, 'b) t -> ('a, 'b) t

invert t is a synonym for bitwise_not.

Sourceval lshift : ('a, 'b) t -> int -> ('a, 'b) t

lshift t shift left-shifts elements by shift bits.

Equivalent to multiplication by 2^shift. Overflow wraps around.

  # let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    lshift x 2
  - : (int32, int32_elt) t = [4, 8, 12]
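The wraparound note can be illustrated with OCaml's Int32 module, which uses 32-bit two's-complement arithmetic like an int32 element. This is a scalar illustration in plain OCaml, not Nx code; lshift_all is a hypothetical helper.

```ocaml
(* Hypothetical helper mirroring lshift on a flat int32 array.
   Int32 arithmetic is 32-bit two's complement, so bits shifted past
   bit 31 are discarded and the result wraps. *)
let lshift_all (xs : int32 array) (shift : int) : int32 array =
  Array.map (fun x -> Int32.shift_left x shift) xs

let () =
  (* Same inputs as the example above *)
  lshift_all [| 1l; 2l; 3l |] 2 |> Array.iter (Printf.printf "%ld ");
  print_newline ();
  (* Shifting 1 into the sign bit yields the most negative int32 *)
  Printf.printf "%ld %ld\n" (Int32.shift_left 1l 31) Int32.min_int
```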
Sourceval rshift : ('a, 'b) t -> int -> ('a, 'b) t

rshift t shift right-shifts elements by shift bits.

Equivalent to integer division by 2^shift (rounds toward zero).

  # let x = create int32 [| 3 |] [| 8l; 9l; 10l |] in
    rshift x 2
  - : (int32, int32_elt) t = [2, 2, 2]

Reduction Operations

Functions that reduce array dimensions.

Sourceval sum : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

sum ?axes ?keepdims t sums elements along specified axes.

Default sums all axes (returns scalar). If keepdims is true, retains reduced dimensions with size 1. Negative axes count from end.

  # let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
    sum x |> get_item []
  - : float = 10.
  # let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
    sum ~axes:[| 0 |] x
  - : (float, float32_elt) t = [4, 6]
  # let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
    sum ~axes:[| 1 |] ~keepdims:true x
  - : (float, float32_elt) t = [[3]]
  # let x = create float32 [| 1; 3 |] [| 1.; 2.; 3. |] in
    sum ~axes:[| -1 |] x
  - : (float, float32_elt) t = [6]
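For a 2-D C-contiguous tensor, reducing a single axis amounts to the loops below. This plain-OCaml sketch (sum_axis is a hypothetical name, not part of Nx, and it handles one axis rather than an axes array) returns the result shape alongside the data so the effect of keepdims is visible.

```ocaml
(* Sum a row-major 2-D array (rows x cols) along one axis.
   Negative axes count from the end: -1 means the last axis. *)
let sum_axis ~(rows : int) ~(cols : int) ~(axis : int) ?(keepdims = false)
    (data : float array) : int array * float array =
  let axis = if axis < 0 then axis + 2 else axis in
  match axis with
  | 0 ->
      (* Collapse rows: one output per column *)
      let out = Array.make cols 0. in
      for r = 0 to rows - 1 do
        for c = 0 to cols - 1 do
          out.(c) <- out.(c) +. data.((r * cols) + c)
        done
      done;
      ((if keepdims then [| 1; cols |] else [| cols |]), out)
  | 1 ->
      (* Collapse columns: one output per row *)
      let out = Array.make rows 0. in
      for r = 0 to rows - 1 do
        for c = 0 to cols - 1 do
          out.(r) <- out.(r) +. data.((r * cols) + c)
        done
      done;
      ((if keepdims then [| rows; 1 |] else [| rows |]), out)
  | _ -> invalid_arg "sum_axis: axis out of range for a 2-D array"
```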
Sourceval max : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

max ?axes ?keepdims t finds maximum along axes.

Default reduces all axes. NaN propagates (any NaN input gives NaN output).

  # let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
    max x |> get_item []
  - : float = 6.
  # let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
    max ~axes:[| 0 |] x
  - : (float, float32_elt) t = [3, 4]
  # let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
    max ~axes:[| 1 |] ~keepdims:true x
  - : (float, float32_elt) t = [[2]]
Sourceval min : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

min ?axes ?keepdims t finds minimum along axes.

Default reduces all axes. NaN propagates (any NaN input gives NaN output).

  # let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
    min x |> get_item []
  - : float = 1.
  # let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
    min ~axes:[| 0 |] x
  - : (float, float32_elt) t = [1, 2]
Sourceval prod : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

prod ?axes ?keepdims t computes product along axes.

Default multiplies all elements. The product over zero elements (an empty reduction) is 1.

  # let x = create int32 [| 3 |] [| 2l; 3l; 4l |] in
    prod x |> get_item []
  - : int32 = 24l
  # let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    prod ~axes:[| 0 |] x
  - : (int32, int32_elt) t = [3, 8]
Sourceval mean : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

mean ?axes ?keepdims t computes arithmetic mean along axes.

Sum of elements divided by count. NaN propagates.

  # let x = create float32 [| 4 |] [| 1.; 2.; 3.; 4. |] in
    mean x |> get_item []
  - : float = 2.5
  # let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
    mean ~axes:[| 1 |] x
  - : (float, float32_elt) t = [2, 5]
Sourceval var : ?axes:int array -> ?keepdims:bool -> ?ddof:int -> ('a, 'b) t -> ('a, 'b) t

var ?axes ?keepdims ?ddof t computes variance along axes.

ddof is delta degrees of freedom. Default 0 (population variance). Use 1 for sample variance. Variance = sum((x - mean(x))²) / (N - ddof), where N is the number of elements reduced.

  # let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
    var x |> get_item []
  - : float = 2.
  # let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
    var ~ddof:1 x |> get_item []
  - : float = 2.5
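The ddof formula can be sketched over a flat float array in plain OCaml (not Nx code; variance is a hypothetical name). Raising when N - ddof is not positive is a choice made for this sketch only.

```ocaml
(* Variance with delta degrees of freedom: sum((x - mean)^2) / (N - ddof).
   ddof = 0 gives the population variance, ddof = 1 the sample variance. *)
let variance ?(ddof = 0) (xs : float array) : float =
  let n = Array.length xs in
  if n - ddof <= 0 then invalid_arg "variance: need more elements than ddof";
  let mean = Array.fold_left ( +. ) 0. xs /. float_of_int n in
  let ss =
    Array.fold_left (fun acc x -> let d = x -. mean in acc +. (d *. d)) 0. xs
  in
  ss /. float_of_int (n - ddof)

let () =
  let xs = [| 1.; 2.; 3.; 4.; 5. |] in
  Printf.printf "population=%g sample=%g std=%g\n"
    (variance xs) (variance ~ddof:1 xs) (Float.sqrt (variance xs))
```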
Sourceval std : ?axes:int array -> ?keepdims:bool -> ?ddof:int -> ('a, 'b) t -> ('a, 'b) t

std ?axes ?keepdims ?ddof t computes standard deviation.

Square root of variance: sqrt(var(t, ddof)). See var for ddof meaning.

  # let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
    std x |> get_item [] |> Float.round
  - : float = 1.
  # let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
    std ~ddof:1 x |> get_item [] |> Float.round
  - : float = 2.
Sourceval all : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> (int, uint8_elt) t

all ?axes ?keepdims t tests if all elements are true (non-zero).

Returns 1 if all elements along axes are non-zero, 0 otherwise.

  # let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    all x |> get_item []
  - : int = 1
  # let x = create int32 [| 3 |] [| 1l; 0l; 3l |] in
    all x |> get_item []
  - : int = 0
  # let x = create int32 [| 2; 2 |] [| 1l; 0l; 1l; 1l |] in
    all ~axes:[| 1 |] x
  - : (int, uint8_elt) t = [0, 1]
Sourceval any : ?axes:int array -> ?keepdims:bool -> ('a, 'b) t -> (int, uint8_elt) t

any ?axes ?keepdims t tests if any element is true (non-zero).

Returns 1 if any element along axes is non-zero, 0 if all are zero.

  # let x = create int32 [| 3 |] [| 0l; 0l; 1l |] in
    any x |> get_item []
  - : int = 1
  # let x = create int32 [| 3 |] [| 0l; 0l; 0l |] in
    any x |> get_item []
  - : int = 0
  # let x = create int32 [| 2; 2 |] [| 0l; 0l; 0l; 1l |] in
    any ~axes:[| 1 |] x
  - : (int, uint8_elt) t = [0, 1]
Sourceval argmax : ?axis:int -> ?keepdims:bool -> ('a, 'b) t -> (int32, int32_elt) t

argmax ?axis ?keepdims t finds indices of maximum values.

Returns index of first occurrence for ties. If axis not specified, operates on flattened tensor and returns scalar.

  # let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
    argmax x |> get_item []
  - : int32 = 4l
  # let x = create int32 [| 2; 3 |] [| 1l; 5l; 3l; 2l; 4l; 6l |] in
    argmax ~axis:1 x
  - : (int32, int32_elt) t = [1, 2]
Sourceval argmin : ?axis:int -> ?keepdims:bool -> ('a, 'b) t -> (int32, int32_elt) t

argmin ?axis ?keepdims t finds indices of minimum values.

Returns index of first occurrence for ties. If axis not specified, operates on flattened tensor and returns scalar.

  # let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
    argmin x |> get_item []
  - : int32 = 1l
  # let x = create int32 [| 2; 3 |] [| 5l; 2l; 3l; 1l; 4l; 0l |] in
    argmin ~axis:1 x
  - : (int32, int32_elt) t = [1, 2]
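The tie-breaking rule (first occurrence wins) comes down to a strict comparison in a linear scan, as in this plain-OCaml sketch over a flat array. It covers only the case without axis, where the tensor is treated as flattened; arg_best is a hypothetical helper, not an Nx function.

```ocaml
(* Index of the first "best" element of a non-empty array. A strict
   comparison never replaces an earlier index on ties. *)
let arg_best (better : 'a -> 'a -> bool) (xs : 'a array) : int =
  if Array.length xs = 0 then invalid_arg "arg_best: empty array";
  let best = ref 0 in
  for i = 1 to Array.length xs - 1 do
    if better xs.(i) xs.(!best) then best := i
  done;
  !best

let argmax xs = arg_best ( > ) xs
let argmin xs = arg_best ( < ) xs

let () =
  let xs = [| 3; 1; 4; 1; 5 |] in
  Printf.printf "argmax=%d argmin=%d\n" (argmax xs) (argmin xs)
```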

Sorting and Searching

Functions for sorting arrays and finding indices.

Sourceval sort : ?descending:bool -> ?axis:int -> ('a, 'b) t -> ('a, 'b) t * (int32, int32_elt) t

sort ?descending ?axis t sorts elements along axis.

Returns (sorted_values, indices) where indices map sorted positions to original positions. Default sorts last axis in ascending order.

Algorithm: Bitonic sort (parallel-friendly, stable)

  • Pads to power of 2 with inf/-inf for correctness
  • O(n log² n) comparisons, O(log² n) depth
  • Stable: preserves relative order of equal elements
  • First occurrence wins for duplicate values

Special values:

  • NaN: sorted to end (ascending) or beginning (descending)
  • inf/-inf: sorted normally
  • For integers: uses max/min values for padding

  # let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
    sort x
  - : (int32, int32_elt) t * (int32, int32_elt) t =
  ([1, 1, 3, 4, 5], [1, 3, 0, 2, 4])
  # let x = create int32 [| 2; 2 |] [| 3l; 1l; 1l; 4l |] in
    sort ~descending:true ~axis:0 x
  - : (int32, int32_elt) t * (int32, int32_elt) t =
  ([[3, 4],
    [1, 1]], [[0, 1],
              [1, 0]])
  # let x = create float32 [| 4 |] [| Float.nan; 1.; 2.; Float.nan |] in
    let v, _ = sort x in
    v
  - : (float, float32_elt) t = [1, 2, nan, nan]
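The (values, indices) pair and the stability guarantee can be reproduced for a 1-D array in plain OCaml. This sketch swaps in the stdlib's stable merge sort (Array.stable_sort) rather than bitonic sort, and sort_with_indices is a hypothetical name. It also does not follow the NaN rule above, because OCaml's compare orders NaN before every other float.

```ocaml
(* Sort values and report original positions. Array.stable_sort is a
   merge sort, so equal values keep their original relative order. *)
let sort_with_indices ?(descending = false) (xs : 'a array) : 'a array * int array =
  let idx = Array.init (Array.length xs) Fun.id in
  let cmp i j =
    let c = compare xs.(i) xs.(j) in
    if descending then -c else c
  in
  Array.stable_sort cmp idx;
  (Array.map (fun i -> xs.(i)) idx, idx)

let () =
  let values, indices = sort_with_indices [| 3; 1; 4; 1; 5 |] in
  Array.iter (Printf.printf "%d ") values;
  print_newline ();
  Array.iter (Printf.printf "%d ") indices;
  print_newline ()
```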
Sourceval argsort : ?descending:bool -> ?axis:int -> ('a, 'b) t -> (int32, int32_elt) t

argsort ?descending ?axis t returns indices that would sort tensor.

Equivalent to snd (sort ?descending ?axis t). Returns indices such that taking elements at these indices yields sorted array.

For 1-D: result[i] is the index of the i-th smallest element. For N-D: sorts along the specified axis independently.

  # let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
    argsort x
  - : (int32, int32_elt) t = [1, 3, 0, 2, 4]
  # let x = create int32 [| 2; 3 |] [| 3l; 1l; 4l; 2l; 5l; 0l |] in
    argsort ~axis:1 x
  - : (int32, int32_elt) t = [[1, 0, 2],
                              [2, 0, 1]]

Linear Algebra

Matrix operations and linear algebra functions.

Sourceval dot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

dot a b computes generalized dot product.

Contracts last axis of a with:

  • 1-D b: the only axis (axis 0)
  • N-D b: second-to-last axis (axis -2)

Dimension rules:

  • 1-D × 1-D: inner product, returns scalar
  • 2-D × 2-D: matrix multiplication
  • N-D × M-D: batched contraction over all but contracted axes

Supports broadcasting on batch dimensions. Result shape is concatenation of:

  • Broadcasted batch dims
  • Remaining dims from a (except last)
  • Remaining dims from b (except contracted axis)

Raises Invalid_argument if contraction axes have different sizes or inputs are 0-D.

  # let a = create float32 [| 2 |] [| 1.; 2. |] in
    let b = create float32 [| 2 |] [| 3.; 4. |] in
    dot a b |> get_item []
  - : float = 11.
  # let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
    let b = create float32 [| 2; 2 |] [| 5.; 6.; 7.; 8. |] in
    dot a b
  - : (float, float32_elt) t = [[19, 22],
                                [43, 50]]
  # dot (ones float32 [| 3; 4; 5 |]) (ones float32 [| 5; 6 |]) |> shape
  - : int array = [|3; 4; 6|]
  # dot (ones float32 [| 2; 3; 4; 5 |]) (ones float32 [| 3; 5; 6 |]) |> shape
  - : int array = [|2; 3; 4; 6|]
Sourceval matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

matmul a b computes matrix multiplication with broadcasting.

Follows NumPy's @ operator semantics:

  • 1-D × 1-D: inner product (returns scalar tensor)
  • 1-D × N-D: treated as [1 × k] @ [... × k × n] → [... × n]
  • N-D × 1-D: treated as [... × m × k] @ [k × 1] → [... × m]
  • N-D × M-D: batched matrix multiply on last 2 dimensions

Broadcasting rules:

  • All dimensions except last 2 are broadcast together
  • For 1-D inputs, dimension is temporarily added then removed
  • Inner dimensions must match: a.shape[-1] == b.shape[-2]

Result shape:

  • Batch dims: broadcast(a.shape[:-2], b.shape[:-2])
  • Matrix dims: [..., a.shape[-2], b.shape[-1]]
  • 1-D adjustments applied after

  # let a = create float32 [| 3 |] [| 1.; 2.; 3. |] in
    let b = create float32 [| 3 |] [| 4.; 5.; 6. |] in
    matmul a b |> get_item []
  - : float = 32.
  # let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
    let b = create float32 [| 2 |] [| 5.; 6. |] in
    matmul a b
  - : (float, float32_elt) t = [17, 39]
  # let a = create float32 [| 2 |] [| 1.; 2. |] in
    let b = create float32 [| 2; 3 |] [| 3.; 4.; 5.; 6.; 7.; 8. |] in
    matmul a b
  - : (float, float32_elt) t = [15, 18, 21]
  # matmul (ones float32 [| 10; 3; 4 |]) (ones float32 [| 10; 4; 5 |]) |> shape
  - : int array = [|10; 3; 5|]
  # matmul (ones float32 [| 1; 3; 4 |]) (ones float32 [| 5; 4; 2 |]) |> shape
  - : int array = [|5; 3; 2|]
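The shape rules can be checked without any data. This plain-OCaml sketch (hypothetical helpers broadcast and matmul_shape, not Nx functions) computes the result shape those rules predict, which can help validate shapes before calling matmul.

```ocaml
(* Broadcast two batch shapes, aligning from the right. *)
let broadcast (a : int array) (b : int array) : int array =
  let na = Array.length a and nb = Array.length b in
  let n = max na nb in
  Array.init n (fun i ->
      (* Missing leading dimensions behave like size 1 *)
      let da = if i < n - na then 1 else a.(i - (n - na)) in
      let db = if i < n - nb then 1 else b.(i - (n - nb)) in
      if da = db || db = 1 then da
      else if da = 1 then db
      else invalid_arg "broadcast: incompatible batch dimensions")

(* Result shape of matmul: promote 1-D operands, broadcast batch dims,
   then drop the promoted axes. *)
let matmul_shape (a : int array) (b : int array) : int array =
  if Array.length a = 0 || Array.length b = 0 then invalid_arg "matmul_shape: 0-D input";
  let a_vec = Array.length a = 1 and b_vec = Array.length b = 1 in
  (* [k] becomes [1; k] on the left and [k; 1] on the right *)
  let a' = if a_vec then [| 1; a.(0) |] else a in
  let b' = if b_vec then [| b.(0); 1 |] else b in
  let la = Array.length a' and lb = Array.length b' in
  if a'.(la - 1) <> b'.(lb - 2) then invalid_arg "matmul_shape: inner dimensions differ";
  let batch = broadcast (Array.sub a' 0 (la - 2)) (Array.sub b' 0 (lb - 2)) in
  let m = if a_vec then [||] else [| a'.(la - 2) |] in
  let n = if b_vec then [||] else [| b'.(lb - 1) |] in
  Array.concat [ batch; m; n ]
```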

Activation Functions

Neural network activation functions.

Sourceval relu : ('a, 'b) t -> ('a, 'b) t

relu t applies Rectified Linear Unit: max(0, x).

  # let x = create float32 [| 5 |] [| -2.; -1.; 0.; 1.; 2. |] in
    relu x
  - : (float, float32_elt) t = [0, 0, 0, 1, 2]
Sourceval relu6 : (float, 'a) t -> (float, 'a) t

relu6 t applies ReLU6: min(max(0, x), 6).

Bounded ReLU used in mobile networks. Clips to the range [0, 6].

  # let x = create float32 [| 3 |] [| -1.; 3.; 8. |] in
    relu6 x
  - : (float, float32_elt) t = [0, 3, 6]
Sourceval sigmoid : (float, 'a) t -> (float, 'a) t

sigmoid t applies logistic sigmoid: 1 / (1 + exp(-x)).

Output in range (0, 1). Point-symmetric about (0, 0.5): sigmoid(-x) = 1 - sigmoid(x), with sigmoid(0) = 0.5.

  # sigmoid (scalar float32 0.) |> get_item []
  - : float = 0.5
  # sigmoid (scalar float32 10.) |> get_item [] |> Float.round
  - : float = 1.
  # sigmoid (scalar float32 (-10.)) |> get_item [] |> Float.round
  - : float = 0.
Sourceval hard_sigmoid : ?alpha:float -> ?beta:float -> (float, 'a) t -> (float, 'a) t

hard_sigmoid ?alpha ?beta t applies piecewise linear sigmoid approximation.

Default alpha = 1/6, beta = 0.5.
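The formula is not spelled out here. A common definition that matches these defaults is clip(alpha * x + beta, 0, 1), which also agrees with the hard_silu example below. The scalar sketch in plain OCaml assumes that definition; it is not taken from Nx's source.

```ocaml
(* Assumed definition: clip(alpha * x + beta, 0, 1). With alpha = 1/6 and
   beta = 0.5 it rises linearly from 0 at x = -3 to 1 at x = 3. *)
let hard_sigmoid ?(alpha = 1. /. 6.) ?(beta = 0.5) (x : float) : float =
  Float.max 0. (Float.min 1. ((alpha *. x) +. beta))

let () =
  List.iter
    (fun x -> Printf.printf "hard_sigmoid(%g) = %g\n" x (hard_sigmoid x))
    [ -4.; -3.; 0.; 1.5; 3.; 4. ]
```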

Sourceval softplus : (float, 'a) t -> (float, 'a) t

softplus t applies smooth ReLU: log(1 + exp(x)).

Smooth approximation to ReLU. Always positive, differentiable everywhere.

  # softplus (scalar float32 0.) |> get_item [] |> Float.round
  - : float = 1.
  # softplus (scalar float32 100.) |> get_item [] |> Float.round
  - : float = infinity
Sourceval silu : (float, 'a) t -> (float, 'a) t

silu t applies Sigmoid Linear Unit: x * sigmoid(x).

Also called Swish. Smooth, non-monotonic activation.

  # silu (scalar float32 0.) |> get_item []
  - : float = 0.
  # silu (scalar float32 1.) |> get_item [] |> Float.round
  - : float = 1.
  # silu (scalar float32 (-1.)) |> get_item [] |> Float.round
  - : float = -0.
Sourceval hard_silu : (float, 'a) t -> (float, 'a) t

hard_silu t applies x * hard_sigmoid(x).

Piecewise linear approximation of SiLU that is cheaper to compute.

  # let x = create float32 [| 3 |] [| -3.; 0.; 3. |] in
    hard_silu x
  - : (float, float32_elt) t = [-0, 0, 3]
Sourceval log_sigmoid : (float, 'a) t -> (float, 'a) t

log_sigmoid t computes log(sigmoid(x)).

Numerically stable version of log(1/(1+exp(-x))). Always negative.

  # log_sigmoid (scalar float32 0.) |> get_item [] |> Float.round
  - : float = -1.
  # log_sigmoid (scalar float32 100.) |> get_item [] |> Float.abs |> (fun x -> x < 0.001)
  - : bool = true
Sourceval leaky_relu : ?negative_slope:float -> (float, 'a) t -> (float, 'a) t

leaky_relu ?negative_slope t applies Leaky ReLU.

Default negative_slope = 0.01. Returns x if x > 0, else negative_slope * x.
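A scalar version of the rule above, in plain OCaml rather than Nx, makes the default slope and the behavior at zero explicit: x = 0 falls into the else branch and yields 0. leaky_relu here is a standalone sketch operating on a single float.

```ocaml
(* Scalar Leaky ReLU: x if x > 0, else negative_slope * x. *)
let leaky_relu ?(negative_slope = 0.01) (x : float) : float =
  if x > 0. then x else negative_slope *. x

let () =
  (* Partial application fixes a custom slope *)
  [| -2.; 0.; 3. |]
  |> Array.map (leaky_relu ~negative_slope:0.1)
  |> Array.iter (Printf.printf "%g ");
  print_newline ()
```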

Sourceval hard_tanh : (float, 'a) t -> (float, 'a) t

hard_tanh t clips values to [-1, 1].

Linear on [-1, 1], saturates outside. Cheaper than tanh.

  # let x = create float32 [| 5 |] [| -2.; -0.5; 0.; 0.5; 2. |] in
    hard_tanh x
  - : (float, float32_elt) t = [-1, -0.5, 0, 0.5, 1]
Sourceval elu : ?alpha:float -> (float, 'a) t -> (float, 'a) t

elu ?alpha t applies Exponential Linear Unit.

Default alpha = 1.0. Returns x if x > 0, else alpha * (exp(x) - 1). Smooth for x < 0, helps with vanishing gradients.

  # elu (scalar float32 1.) |> get_item []
  - : float = 1.
  # elu (scalar float32 0.) |> get_item []
  - : float = 0.
  # elu (scalar float32 (-1.)) |> get_item [] |> Float.round
  - : float = -1.
Sourceval selu : (float, 'a) t -> (float, 'a) t

selu t applies Scaled ELU with fixed alpha=1.67326, lambda=1.0507.

Self-normalizing activation. Preserves mean 0 and variance 1 in deep networks under certain conditions.

  # selu (scalar float32 0.) |> get_item []
  - : float = 0.
  # selu (scalar float32 1.) |> get_item [] |> Float.round
  - : float = 1.
Sourceval softmax : ?axes:int array -> (float, 'a) t -> (float, 'a) t

softmax ?axes t applies softmax normalization.

Default axes is [| -1 |] (the last axis). Computes exp(x - max) / sum(exp(x - max)) for numerical stability. Output sums to 1 along the specified axes.

  # let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
    softmax x |> to_array |> Array.map Float.round
  - : float array = [|0.; 0.; 1.|]
  # let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
    sum (softmax x) |> get_item []
  - : float = 1.
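The max-subtraction trick can be sketched for a single 1-D array in plain OCaml. This is not Nx's implementation and it has no axes argument; softmax here is a standalone helper.

```ocaml
(* Numerically stable softmax over a 1-D array:
   exp(x - max) / sum(exp(x - max)). Subtracting the maximum keeps every
   exponent <= 0, so exp cannot overflow. *)
let softmax (xs : float array) : float array =
  if Array.length xs = 0 then invalid_arg "softmax: empty array";
  let m = Array.fold_left Float.max Float.neg_infinity xs in
  let exps = Array.map (fun x -> Float.exp (x -. m)) xs in
  let total = Array.fold_left ( +. ) 0. exps in
  Array.map (fun e -> e /. total) exps

let () =
  (* exp 1000. alone would overflow to infinity; the shifted form does not *)
  softmax [| 1000.; 1000. |] |> Array.iter (Printf.printf "%g ");
  print_newline ()
```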
Sourceval gelu_approx : (float, 'a) t -> (float, 'a) t

gelu_approx t applies Gaussian Error Linear Unit approximation.

Smooth activation: x * Φ(x), where Φ is the Gaussian CDF. Uses the tanh approximation of Φ for efficiency.

  # gelu_approx (scalar float32 0.) |> get_item []
  - : float = 0.
  # gelu_approx (scalar float32 1.) |> get_item [] |> Float.round
  - : float = 1.
Sourceval softsign : (float, 'a) t -> (float, 'a) t

softsign t computes x / (|x| + 1).

Similar to tanh but computationally cheaper. Range (-1, 1).

  # let x = create float32 [| 3 |] [| -10.; 0.; 10. |] in
    softsign x
  - : (float, float32_elt) t = [-0.909091, 0, 0.909091]
Sourceval mish : (float, 'a) t -> (float, 'a) t

mish t applies Mish activation: x * tanh(softplus(x)).

Self-regularizing non-monotonic activation. Smoother than ReLU.

  # mish (scalar float32 0.) |> get_item [] |> Float.abs |> (fun x -> x < 0.001)
  - : bool = true
  # mish (scalar float32 (-10.)) |> get_item [] |> Float.round
  - : float = -0.

Convolution and Pooling

Neural network convolution and pooling operations.

Sourceval im2col : kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t -> ('a, 'b) t

im2col ~kernel_size ~stride ~dilation ~padding t extracts sliding local blocks from tensor.

Extracts patches of size kernel_size from the input tensor at the specified stride and dilation.

  • kernel_size: size of sliding blocks to extract
  • stride: step between consecutive blocks
  • dilation: spacing between kernel elements
  • padding: (before, after) padding for each spatial dimension

For a 4-D input [batch; channels; height; width], produces output of shape [batch; channels * kh * kw; num_patches], where kh, kw are the kernel dimensions and num_patches = num_patches_h * num_patches_w depends on stride, dilation, and padding.

  # let x = arange_f float32 0. 16. 1. |> reshape [| 1; 1; 4; 4 |] in
    im2col ~kernel_size:[|2; 2|] ~stride:[|1; 1|]
           ~dilation:[|1; 1|] ~padding:[|(0, 0); (0, 0)|] x |> shape
  - : int array = [|1; 4; 9|]
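The patch count in this example can be reproduced with standard sliding-window arithmetic. The plain-OCaml helper below (num_patches is hypothetical, not an Nx function) assumes the usual formula floor((size + before + after - dilation*(kernel-1) - 1) / stride) + 1. For the 4x4 input with a 2x2 kernel it gives 3 positions per axis, hence 9 patches.

```ocaml
(* Number of window positions along one spatial dimension, with explicit
   (before, after) padding. *)
let num_patches ~size ~kernel ~stride ~dilation ~(padding : int * int) : int =
  let before, after = padding in
  (* Span covered by a dilated kernel *)
  let effective_kernel = (dilation * (kernel - 1)) + 1 in
  let padded = size + before + after in
  if padded < effective_kernel then 0 else ((padded - effective_kernel) / stride) + 1

let () =
  let per_axis = num_patches ~size:4 ~kernel:2 ~stride:1 ~dilation:1 ~padding:(0, 0) in
  Printf.printf "per axis: %d, total: %d\n" per_axis (per_axis * per_axis)
```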
Sourceval col2im : output_size:int array -> kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t -> ('a, 'b) t

col2im ~output_size ~kernel_size ~stride ~dilation ~padding t combines sliding local blocks into tensor.

This is the inverse of im2col. Accumulates values from the unfolded representation back into spatial dimensions. Overlapping regions are summed.

  • output_size: target spatial dimensions [height; width]
  • kernel_size: size of sliding blocks
  • stride: step between consecutive blocks
  • dilation: spacing between kernel elements
  • padding: (before, after) padding for each spatial dimension

For input shape [batch; channels * kh * kw; num_patches], produces output [batch; channels; height; width].

  # let unfolded = create float32 [| 1; 4; 9 |] (Array.init 36 Float.of_int) in
    col2im ~output_size:[|4; 4|] ~kernel_size:[|2; 2|]
           ~stride:[|1; 1|] ~dilation:[|1; 1|]
           ~padding:[|(0, 0); (0, 0)|] unfolded |> shape
  - : int array = [|1; 1; 4; 4|]
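A col2im sketch over plain arrays, assuming one channel, no padding, the same stride and dilation on both axes, and hypothetical names, scatters each column back to its patch location and sums the overlaps. Folding a matrix of ones shows how many 2x2 patches cover each pixel of a 4x4 output.

```ocaml
(* Hypothetical col2im sketch: one channel, no padding, same stride and
   dilation on both axes. cols has kh*kw rows (kernel offsets) and one
   column per patch in row-major patch order. Overlaps are summed. *)

let col2im ~h ~w ~kh ~kw ~stride ~dilation (cols : float array array) =
  assert (Array.length cols = kh * kw);
  (* Patches per row of the output grid, needed to decode column index p. *)
  let ow = ((w - (dilation * (kw - 1)) - 1) / stride) + 1 in
  let img = Array.make_matrix h w 0. in
  Array.iteri
    (fun r row ->
      let ki = r / kw and kj = r mod kw in
      Array.iteri
        (fun p v ->
          let i = ((p / ow) * stride) + (ki * dilation)
          and j = ((p mod ow) * stride) + (kj * dilation) in
          img.(i).(j) <- img.(i).(j) +. v)
        row)
    cols;
  img

(* Folding all-ones columns counts how many 2x2 patches cover each pixel. *)
let counts = col2im ~h:4 ~w:4 ~kh:2 ~kw:2 ~stride:1 ~dilation:1 (Array.make_matrix 4 9 1.)

let () =
  Array.iter
    (fun row ->
      print_endline (String.concat " " (Array.to_list (Array.map (Printf.sprintf "%.0f") row))))
    counts
(* prints:
   1 2 2 1
   2 4 4 2
   2 4 4 2
   1 2 2 1 *)
```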
val correlate1d : ?groups:int -> ?stride:int -> ?padding_mode:[ `Full | `Same | `Valid ] -> ?dilation:int -> ?fillvalue:float -> ?bias:(float, 'a) t -> (float, 'a) t -> (float, 'a) t -> (float, 'a) t

correlate1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 1D cross-correlation (no kernel flip).

  • x: input [batch_size; channels_in; width]
  • w: weights [channels_out; channels_in/groups; kernel_width]
  • bias: optional per-channel bias [channels_out]
  • groups: split input/output channels into groups (default 1)
  • stride: step between windows (default 1)
  • padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)
  • dilation: spacing between kernel elements (default 1)
  • fillvalue: padding value (default 0.0)

Output width depends on padding_mode (the `Valid and `Full divisions round down):

  • `Valid: (width - dilation*(kernel-1) - 1)/stride + 1
  • `Same: width/stride, rounded up
  • `Full: (width + dilation*(kernel-1) - 1)/stride + 1

  # let x = create float32 [| 1; 1; 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
    let w = create float32 [| 1; 1; 3 |] [| 1.; 0.; -1. |] in
    correlate1d x w |> shape
  - : int array = [|1; 1; 3|]
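The output-width rules and the no-flip definition can be sketched for a single channel without groups or bias. The helpers out_width and correlate1d_valid are hypothetical and work on plain float arrays; the sketch assumes the kernel fits inside the input so that integer division rounds down as the formulas require.

```ocaml
(* Hypothetical helpers for one channel, no groups, no bias. *)

(* Output width per padding mode; assumes the kernel fits the input so the
   `Valid numerator is non-negative and integer division rounds down. *)
let out_width mode ~width ~kernel ~stride ~dilation =
  let span = dilation * (kernel - 1) in
  match mode with
  | `Valid -> ((width - span - 1) / stride) + 1
  | `Same -> (width + stride - 1) / stride (* ceil (width / stride) *)
  | `Full -> ((width + span - 1) / stride) + 1

(* `Valid cross-correlation: y.(i) = sum_k x.(i*stride + k*dilation) * w.(k).
   The kernel is used as given, with no flip. *)
let correlate1d_valid ?(stride = 1) ?(dilation = 1) x w =
  let n =
    out_width `Valid ~width:(Array.length x) ~kernel:(Array.length w) ~stride ~dilation
  in
  Array.init n (fun i ->
      let acc = ref 0. in
      Array.iteri (fun k wk -> acc := !acc +. (x.((i * stride) + (k * dilation)) *. wk)) w;
      !acc)

(* Same input and kernel as the example above: 3 outputs, each x(i) - x(i+2). *)
let y = correlate1d_valid [| 1.; 2.; 3.; 4.; 5. |] [| 1.; 0.; -1. |]

let () = Array.iter (Printf.printf "%g ") y; print_newline ()
(* prints: -2 -2 -2 *)
```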
val correlate2d : ?groups:int -> ?stride:(int * int) -> ?padding_mode:[ `Full | `Same | `Valid ] -> ?dilation:(int * int) -> ?fillvalue:float -> ?bias:(float, 'a) t -> (float, 'a) t -> (float, 'a) t -> (float, 'a) t

correlate2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 2D cross-correlation (no kernel flip).

  • x: input [batch; channels_in; height; width]
  • w: weights [channels_out; channels_in/groups; kernel_h; kernel_w]
  • bias: optional per-channel bias [channels_out]
  • stride: (stride_h, stride_w) step between windows (default (1,1))
  • dilation: (dilation_h, dilation_w) kernel spacing (default (1,1))
  • padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)

Uses Winograd F(4,3) for 3×3 kernels with stride 1 when beneficial. For `Same with even kernels, pads more on the bottom/right (SciPy convention).

  # let image = ones float32 [| 1; 1; 5; 5 |] in
    let sobel_x = create float32 [| 1; 1; 3; 3 |] [| 1.; 0.; -1.; 2.; 0.; -2.; 1.; 0.; -1. |] in
    correlate2d image sobel_x |> shape
  - : int array = [|1; 1; 3; 3|]
val convolve1d : ?groups:int -> ?stride:int -> ?padding_mode:[< `Full | `Same | `Valid ] -> ?dilation:int -> ?fillvalue:'a -> ?bias:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

convolve1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 1D convolution (flips kernel before correlation).

Same parameters as correlate1d, but the kernel is flipped. For `Same with even kernels, pads more on the left (NumPy convention).

  # let x = create float32 [| 1; 1; 3 |] [| 1.; 2.; 3. |] in
    let w = create float32 [| 1; 1; 2 |] [| 4.; 5. |] in
    convolve1d x w
  - : (float, float32_elt) t = [[[13, 22]]]
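Convolution differs from correlation only by the kernel flip. This sketch (plain float arrays, one channel, `Valid padding, hypothetical names) reverses the kernel and reuses a correlation loop, reproducing the [13, 22] result above.

```ocaml
(* Hypothetical sketch: one channel, `Valid padding, plain float arrays. *)

(* Cross-correlation: slide w over x without flipping it. *)
let correlate_valid x w =
  let n = Array.length x - Array.length w + 1 in
  Array.init n (fun i ->
      let acc = ref 0. in
      Array.iteri (fun k wk -> acc := !acc +. (x.(i + k) *. wk)) w;
      !acc)

(* Convolution: reverse the kernel, then correlate. *)
let convolve_valid x w =
  let m = Array.length w in
  correlate_valid x (Array.init m (fun k -> w.(m - 1 - k)))

(* [|1; 2; 3|] convolved with [|4; 5|]: 1*5 + 2*4 = 13, 2*5 + 3*4 = 22. *)
let y = convolve_valid [| 1.; 2.; 3. |] [| 4.; 5. |]
```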
val convolve2d : ?groups:int -> ?stride:(int * int) -> ?padding_mode:[< `Full | `Same | `Valid ] -> ?dilation:(int * int) -> ?fillvalue:'a -> ?bias:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

convolve2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 2D convolution (flips kernel before correlation).

Same parameters as correlate2d, but the kernel is flipped horizontally and vertically. For `Same with even kernels, pads more on the top/left.

  # let image = ones float32 [| 1; 1; 5; 5 |] in
    let gaussian = create float32 [| 1; 1; 3; 3 |] [| 1.; 2.; 1.; 2.; 4.; 2.; 1.; 2.; 1. |] in
    convolve2d image (mul_s gaussian (1. /. 16.)) |> shape
  - : int array = [|1; 1; 3; 3|]
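The same flip relationship holds in two dimensions. The sketch below assumes a single-channel image stored as float array array and `Valid padding; the function names are illustrative, not Nx's. On a constant image, the Sobel-x kernel from the correlate2d example gives zeros (there is no horizontal gradient), and the normalized Gaussian from the convolve2d example gives ones, since its weights sum to 1; both outputs are 3x3.

```ocaml
(* Hypothetical sketch: single-channel image, `Valid padding. *)

let correlate2d_valid img k =
  let kh = Array.length k and kw = Array.length k.(0) in
  let oh = Array.length img - kh + 1 and ow = Array.length img.(0) - kw + 1 in
  Array.init oh (fun i ->
      Array.init ow (fun j ->
          let acc = ref 0. in
          for a = 0 to kh - 1 do
            for b = 0 to kw - 1 do
              acc := !acc +. (img.(i + a).(j + b) *. k.(a).(b))
            done
          done;
          !acc))

(* Flip the kernel vertically and horizontally. *)
let flip2d k =
  let kh = Array.length k and kw = Array.length k.(0) in
  Array.init kh (fun a -> Array.init kw (fun b -> k.(kh - 1 - a).(kw - 1 - b)))

(* Convolution = correlation with the flipped kernel. *)
let convolve2d_valid img k = correlate2d_valid img (flip2d k)

let image = Array.make_matrix 5 5 1.
let sobel_x = [| [| 1.; 0.; -1. |]; [| 2.; 0.; -2. |]; [| 1.; 0.; -1. |] |]
let gaussian =
  Array.map (Array.map (fun v -> v /. 16.))
    [| [| 1.; 2.; 1. |]; [| 2.; 4.; 2. |]; [| 1.; 2.; 1. |] |]

let edges = correlate2d_valid image sobel_x (* 3x3 of zeros *)
let blurred = convolve2d_valid image gaussian (* 3x3 of ones *)
```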
val avg_pool1d : kernel_size:int -> ?stride:int -> ?dilation:int -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?ceil_mode:bool -> ?count_include_pad:bool -> (float, 'a) t -> (float, 'a) t

avg_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x applies 1D average pooling.

  • kernel_size: pooling window size
  • stride: step between windows (default: kernel_size)
  • dilation: spacing between kernel elements (default 1)
  • padding_spec: `Valid, `Same, or `Full, as for padding_mode in the convolution functions
  • ceil_mode: use ceiling for output size calculation (default false)
  • count_include_pad: include padding in average (default true)

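Input shape: [batch; channels; width]. Output width: (width + 2*pad - dilation*(kernel-1) - 1)/stride + 1, where the division rounds down (up when ceil_mode is true) and pad is the per-side padding implied by padding_spec.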

  # let x = create float32 [| 1; 1; 4 |] [| 1.; 2.; 3.; 4. |] in
    avg_pool1d ~kernel_size:2 x
  - : (float, float32_elt) t = [[[1.5, 3.5]]]
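The effect of count_include_pad and ceil_mode can be seen with a small sketch over a plain float array. It assumes symmetric zero padding pad on each side, dilation 1 in the pooling loop, and floor rounding when pooling; pool_out_size only shows how ceil_mode changes the size formula. All names are hypothetical.

```ocaml
(* Hypothetical sketch: plain float arrays, symmetric zero padding [pad] on
   each side, dilation 1 in the pooling loop. *)

(* Output size; the division rounds down, or up when ceil_mode is true. *)
let pool_out_size ?(ceil_mode = false) ~width ~pad ~kernel ~stride ~dilation () =
  let num = width + (2 * pad) - (dilation * (kernel - 1)) - 1 in
  (if ceil_mode then (num + stride - 1) / stride else num / stride) + 1

(* Average pooling with floor rounding. Positions outside the input are
   padding: they add nothing to the sum, and they count toward the divisor
   only when count_include_pad is true. *)
let avg_pool1d ?stride ?(pad = 0) ?(count_include_pad = true) ~kernel x =
  let stride = Option.value stride ~default:kernel in
  let width = Array.length x in
  let n = pool_out_size ~width ~pad ~kernel ~stride ~dilation:1 () in
  Array.init n (fun o ->
      let sum = ref 0. and count = ref 0 in
      for k = 0 to kernel - 1 do
        let i = (o * stride) + k - pad in (* index into the unpadded input *)
        if i >= 0 && i < width then begin
          sum := !sum +. x.(i);
          incr count
        end
      done;
      let denom = if count_include_pad then kernel else !count in
      !sum /. float_of_int denom)

let x = [| 1.; 2.; 3.; 4. |]
let plain = avg_pool1d ~kernel:2 x (* [|1.5; 3.5|], as in the example above *)
let with_pad = avg_pool1d ~pad:1 ~kernel:2 x (* [|0.5; 2.5; 2.|] *)
let without_pad = avg_pool1d ~pad:1 ~count_include_pad:false ~kernel:2 x (* [|1.; 2.5; 4.|] *)
```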
val avg_pool2d : kernel_size:(int * int) -> ?stride:(int * int) -> ?dilation:(int * int) -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?ceil_mode:bool -> ?count_include_pad:bool -> (float, 'a) t -> (float, 'a) t

avg_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x applies 2D average pooling.

  • kernel_size: (height, width) of pooling window
  • stride: (stride_h, stride_w) (default: kernel_size)
  • dilation: (dilation_h, dilation_w) (default (1,1))
  • count_include_pad: whether padding contributes to denominator

Input shape: [batch; channels; height; width]

  # let x = create float32 [| 1; 1; 2; 2 |] [| 1.; 2.; 3.; 4. |] in
    avg_pool2d ~kernel_size:(2, 2) x
  - : (float, float32_elt) t = [[[[2.5]]]]
val max_pool1d : kernel_size:int -> ?stride:int -> ?dilation:int -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?ceil_mode:bool -> ?return_indices:bool -> ('a, 'b) t -> ('a, 'b) t * (int32, int32_elt) t option

max_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 1D max pooling.

  • return_indices: if true, also returns indices of max values for unpooling
  • Other parameters same as avg_pool1d

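Returns (pooled_values, Some indices) if return_indices is true, otherwise (pooled_values, None). Each index is the position of the maximum within its pooling window, not an absolute position in the input: in the example below, the maxima at input positions 1 and 3 both have index 1.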

  # let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
    let vals, idx = max_pool1d ~kernel_size:2 ~return_indices:true x in
    vals, idx
  - : (float, float32_elt) t * (int32, int32_elt) t option =
  ([[[3, 4]]], Some [[[1, 1]]])
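The example returns indices [1, 1] for maxima at input positions 1 and 3, which reads as offsets within each window. Under that reading, max pooling with indices can be sketched as below. The sketch assumes plain float arrays, no padding, a stride that defaults to the kernel size, and that the first maximum wins on ties; the names are hypothetical. The absolute input position of a maximum is o * stride + indices.(o).

```ocaml
(* Hypothetical sketch: plain float arrays, no padding, stride defaulting to
   the kernel size; on ties the first maximum in the window wins. *)
let max_pool1d_with_indices ?stride ~kernel x =
  let stride = Option.value stride ~default:kernel in
  let n = ((Array.length x - kernel) / stride) + 1 in
  let values = Array.make n neg_infinity and indices = Array.make n 0 in
  for o = 0 to n - 1 do
    for k = 0 to kernel - 1 do
      let v = x.((o * stride) + k) in
      (* k is the offset inside window o, not the absolute position *)
      if v > values.(o) then begin
        values.(o) <- v;
        indices.(o) <- k
      end
    done
  done;
  (values, indices)

let values, indices = max_pool1d_with_indices ~kernel:2 [| 1.; 3.; 2.; 4. |]
(* values = [|3.; 4.|], indices = [|1; 1|] *)

(* Absolute input positions of the maxima (stride 2): o * 2 + indices.(o). *)
let absolute = Array.mapi (fun o k -> (o * 2) + k) indices
(* absolute = [|1; 3|] *)
```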
val max_pool2d : kernel_size:(int * int) -> ?stride:(int * int) -> ?dilation:(int * int) -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?ceil_mode:bool -> ?return_indices:bool -> ('a, 'b) t -> ('a, 'b) t * (int32, int32_elt) t option

max_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 2D max pooling.

Parameters same as max_pool1d but for 2D. Indices encode flattened position within each pooling window.

  # let x = create float32 [| 1; 1; 4; 4 |]
      [| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
    let vals, _ = max_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
    vals
  - : (float, float32_elt) t = [[[[4, 8],
                                  [12, 16]]]]
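For 2D windows, one natural encoding of a "flattened position within each pooling window" is a * kw + b for row offset a and column offset b, in row-major order. The sketch below assumes that encoding (it is not confirmed against the Nx implementation), a single channel, no padding, and hypothetical names. In the example input every maximum sits at the bottom-right of its 2x2 window, so every index is 3.

```ocaml
(* Hypothetical sketch: single channel, no padding, float array array input.
   Index encoding assumed: a * kw + b (row-major offset within the window). *)
let max_pool2d_with_indices ~kh ~kw ~sh ~sw img =
  let oh = ((Array.length img - kh) / sh) + 1 in
  let ow = ((Array.length img.(0) - kw) / sw) + 1 in
  let values = Array.make_matrix oh ow neg_infinity in
  let indices = Array.make_matrix oh ow 0 in
  for i = 0 to oh - 1 do
    for j = 0 to ow - 1 do
      for a = 0 to kh - 1 do
        for b = 0 to kw - 1 do
          let v = img.((i * sh) + a).((j * sw) + b) in
          if v > values.(i).(j) then begin
            values.(i).(j) <- v;
            indices.(i).(j) <- (a * kw) + b
          end
        done
      done
    done
  done;
  (values, indices)

(* The 4x4 input from the example above. *)
let img =
  [| [| 1.; 2.; 5.; 6. |]; [| 3.; 4.; 7.; 8. |];
     [| 9.; 10.; 13.; 14. |]; [| 11.; 12.; 15.; 16. |] |]

let values, indices = max_pool2d_with_indices ~kh:2 ~kw:2 ~sh:2 ~sw:2 img
(* values = [|[|4.; 8.|]; [|12.; 16.|]|], indices = [|[|3; 3|]; [|3; 3|]|] *)
```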
val min_pool1d : kernel_size:int -> ?stride:int -> ?dilation:int -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?ceil_mode:bool -> ?return_indices:bool -> ('a, 'b) t -> ('a, 'b) t * (int32, int32_elt) t option

min_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 1D min pooling.

  • return_indices: accepted for symmetry with max_pool1d, but index tracking is not yet implemented
  • Other parameters same as avg_pool1d

Always returns (pooled_values, None).

  # let x = create float32 [| 1; 1; 4 |] [| 4.; 2.; 3.; 1. |] in
    let vals, _ = min_pool1d ~kernel_size:2 x in
    vals
  - : (float, float32_elt) t = [[[2, 1]]]
val min_pool2d : kernel_size:(int * int) -> ?stride:(int * int) -> ?dilation:(int * int) -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?ceil_mode:bool -> ?return_indices:bool -> ('a, 'b) t -> ('a, 'b) t * (int32, int32_elt) t option

min_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 2D min pooling.

Parameters same as min_pool1d but for 2D. Commonly used for morphological erosion operations in image processing.

  # let x = create float32 [| 1; 1; 4; 4 |]
      [| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
    let vals, _ = min_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
    vals
  - : (float, float32_elt) t = [[[[1, 5],
                                  [9, 13]]]]
val max_unpool1d : (int, uint8_elt) t -> ('a, 'b) t -> kernel_size:int -> ?stride:int -> ?dilation:int -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?output_size_opt:int array -> unit -> (int, uint8_elt) t

max_unpool1d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt () reverses max pooling.

  • indices: indices from max_pool1d with return_indices=true
  • values: pooled values to place at indexed positions
  • kernel_size, stride, dilation, padding_spec: must match original pool
  • output_size_opt: exact output shape (inferred if not provided)

Places each value at the position given by its index and fills the remaining positions with zeros. The output size is inferred from the input unless output_size_opt is given.

  # let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
    let pooled, _ = max_pool1d ~kernel_size:2 x in
    pooled
  - : (float, float32_elt) t = [[[3, 4]]]
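Unpooling inverts the index encoding by writing each pooled value back to o * stride + index. This sketch assumes each index is an offset within its pooling window (as in the max_pool1d example, where [1, 1] marks input positions 1 and 3), plain float arrays, and a default output width of (n - 1) * stride + kernel; the names are hypothetical.

```ocaml
(* Hypothetical sketch: within-window indices, plain float arrays,
   stride defaulting to the kernel size. *)
let max_unpool1d ?stride ?output_width ~kernel values indices =
  let stride = Option.value stride ~default:kernel in
  let n = Array.length values in
  (* Assumed default size: just large enough to hold the last window. *)
  let width = Option.value output_width ~default:(((n - 1) * stride) + kernel) in
  let out = Array.make width 0. in
  Array.iteri (fun o v -> out.((o * stride) + indices.(o)) <- v) values;
  out

(* Pooling [|1; 3; 2; 4|] with kernel 2 gives values [|3; 4|] at offsets [|1; 1|]. *)
let restored = max_unpool1d ~kernel:2 [| 3.; 4. |] [| 1; 1 |]
(* restored = [|0.; 3.; 0.; 4.|] *)
```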
val max_unpool2d : (int, uint8_elt) t -> ('a, 'b) t -> kernel_size:(int * int) -> ?stride:(int * int) -> ?dilation:(int * int) -> ?padding_spec:[< `Full | `Same | `Valid ] -> ?output_size_opt:int array -> unit -> (int, uint8_elt) t

max_unpool2d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt () reverses 2D max pooling.

Same as max_unpool1d but for 2D. Indices encode position within each pooling window. Useful for architectures like segmentation networks that need to "remember" where maxima came from.

  # let x = create float32 [| 1; 1; 4; 4 |]
      [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.;
         9.; 10.; 11.; 12.; 13.; 14.; 15.; 16. |] in
    let pooled, _ = max_pool2d ~kernel_size:(2,2) x in
    pooled
  - : (float, float32_elt) t = [[[[6, 8],
                                  [14, 16]]]]
val one_hot : num_classes:int -> ('a, 'b) t -> (int, uint8_elt) t

one_hot ~num_classes indices creates a one-hot encoding of indices.

Adds a new last dimension of size num_classes. Valid indices lie in [0, num_classes); out-of-range indices produce all-zero vectors.

  # let indices = create int32 [| 3 |] [| 0l; 1l; 3l |] in
    one_hot ~num_classes:4 indices
  - : (int, uint8_elt) t = [[1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 0, 1]]
  # let indices = create int32 [| 2; 2 |] [| 0l; 2l; 1l; 0l |] in
    one_hot ~num_classes:3 indices |> shape
  - : int array = [|2; 2; 3|]
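The encoding rule, including the all-zero rows for out-of-range indices, fits in a few lines over plain int arrays. This is a hypothetical helper, not the Nx implementation; it handles a flat array of indices and returns int rows rather than a uint8 tensor.

```ocaml
(* Hypothetical sketch: one row of length num_classes per index. An index
   outside [0, num_classes) matches no class, so its row is all zeros. *)
let one_hot ~num_classes indices =
  Array.map
    (fun idx -> Array.init num_classes (fun c -> if c = idx then 1 else 0))
    indices

let encoded = one_hot ~num_classes:4 [| 0; 1; 3; 7 |]
(* encoded.(3) = [|0; 0; 0; 0|] because 7 is out of range *)
```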

Iteration and Mapping

Functions to iterate over and transform arrays.

val map_item : ('a -> 'a) -> ('a, 'b) t -> ('a, 'b) t

map_item f t applies f to each element.

Operates directly on contiguous data. The mapping is type-preserving: f has type 'a -> 'a, so the result keeps the dtype of t.

val iter_item : ('a -> unit) -> ('a, 'b) t -> unit

iter_item f t applies f to each element for side effects.

val fold_item : ('a -> 'b -> 'a) -> 'a -> ('b, 'c) t -> 'a

fold_item f init t folds f over elements.

val map : (('a, 'b) t -> ('a, 'b) t) -> ('a, 'b) t -> ('a, 'b) t

map f t applies f to each element, passed as a scalar tensor, and assembles the scalar results into a tensor of the same shape.

val iter : (('a, 'b) t -> unit) -> ('a, 'b) t -> unit

iter f t applies f to each element, passed as a scalar tensor, for side effects.

val fold : ('a -> ('b, 'c) t -> 'a) -> 'a -> ('b, 'c) t -> 'a

fold f init t folds f over the elements, each passed as a scalar tensor.

Printing and Display

Functions to display arrays and convert to strings.

val pp_data : Format.formatter -> ('a, 'b) t -> unit

pp_data fmt t pretty-prints tensor data.

val format_to_string : (Format.formatter -> 'a -> unit) -> 'a -> string

format_to_string pp x renders x to a string using the pretty-printer pp.

val print_with_formatter : (Format.formatter -> 'a -> unit) -> 'a -> unit

print_with_formatter pp x prints x using the pretty-printer pp.

val data_to_string : ('a, 'b) t -> string

data_to_string t converts tensor data to string.

val print_data : ('a, 'b) t -> unit

print_data t prints tensor data to stdout.

val pp_dtype : Format.formatter -> ('a, 'b) dtype -> unit

pp_dtype fmt dt pretty-prints dtype.

val dtype_to_string : ('a, 'b) dtype -> string

dtype_to_string dt converts dtype to string.

val shape_to_string : int array -> string

shape_to_string shape formats shape as "2x3x4".

val pp_shape : Format.formatter -> int array -> unit

pp_shape fmt shape pretty-prints shape.

val pp : Format.formatter -> ('a, 'b) t -> unit

pp fmt t pretty-prints tensor info and data.

val print : ('a, 'b) t -> unit

print t prints tensor info and data to stdout.

val to_string : ('a, 'b) t -> string

to_string t converts tensor info and data to string.
