package nx

  1. Overview
  2. Docs

Source file ops_threefry.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
open Nx_core.Dtype
module Shape = Nx_core.Shape
open Internal

(* Threefry 2x32 Core Implementation *)
module Threefry_impl = struct
  (* Parity constant XORed with both key words to derive the third
     key-schedule word. *)
  let ks_parity_32 = 0x1BD11BDA_l

  (* Left-rotation amounts, selected by [round mod 8]. *)
  let r_2x32 = [| 13; 15; 26; 6; 17; 29; 16; 24 |]

  (* [rotl32 x n] rotates the 32-bit word [x] left by [n land 31] bits. *)
  let rotl32 x n =
    let n = n land 31 in
    (* A zero rotation would require a right shift by 32, whose result the
       [Int32] documentation leaves unspecified, so short-circuit it. *)
    if n = 0 then x
    else Int32.(logor (shift_left x n) (shift_right_logical x (32 - n)))

  (* [threefry2x32_20_rounds c0 c1 k0 k1] applies 20 rounds of the Threefry
     2x32 permutation to the counter words [(c0, c1)] under the key words
     [(k0, k1)] and returns the two resulting words. All arithmetic is
     wrapping 32-bit arithmetic. *)
  let threefry2x32_20_rounds (c0 : int32) (c1 : int32) (k0 : int32) (k1 : int32)
      : int32 * int32 =
    let x0 = ref c0 in
    let x1 = ref c1 in
    (* Extended key schedule: the two key words plus their parity word. *)
    let keys = [| k0; k1; Int32.logxor ks_parity_32 (Int32.logxor k0 k1) |] in

    (* Add key-schedule words number [s] to the state; the second state word
       additionally receives the injection counter [s] itself. *)
    let inject_key s =
      x0 := Int32.add !x0 keys.(s mod 3);
      x1 := Int32.add !x1 keys.((s + 1) mod 3);
      x1 := Int32.add !x1 (Int32.of_int s)
    in

    for r = 0 to 19 do
      (* A key injection precedes every group of four rounds. *)
      if r mod 4 = 0 then inject_key (r / 4);
      (* One round: add, rotate, xor. *)
      x0 := Int32.add !x0 !x1;
      x1 := rotl32 !x1 r_2x32.(r mod 8);
      x1 := Int32.logxor !x1 !x0
    done;

    (* Final key injection after the last round. *)
    inject_key (20 / 4);
    (!x0, !x1)
end

(* [kernel_threefry_int32 data_t seed_t out_t start_idx end_idx] computes the
   output elements with logical indices from [start_idx] (inclusive) to
   [end_idx] (exclusive).

   Each output word is the first word returned by
   [Threefry_impl.threefry2x32_20_rounds], called with the data element as
   first counter word, [0l] as second counter word, the seed element as first
   key word and [0xCAFEBABEl] as second key word.

   When both [data_t] and [seed_t] are C-contiguous, the inputs are read
   linearly from their offsets; otherwise every logical index is mapped
   through the input strides. In both cases the result is written at
   [offset out_t + k] for logical index [k].
   NOTE(review): this assumes [out_t] is laid out contiguously — confirm
   against callers. *)
let kernel_threefry_int32 (data_t : (int32, int32_elt) t)
    (seed_t : (int32, int32_elt) t) (out_t : (int32, int32_elt) t) start_idx
    end_idx =
  let data_buf = buffer data_t in
  let seed_buf = buffer seed_t in
  let out_buf = buffer out_t in
  let data_offset = offset data_t in
  let seed_offset = offset seed_t in
  let out_offset = offset out_t in
  let c1_fixed = 0l in
  let k1_fixed = 0xCAFEBABEl in

  (* Hash the element pair found at the given buffer positions and store the
     first result word at [out_pos]. Bounds are not checked. *)
  let hash_at data_pos seed_pos out_pos =
    let d_val = Bigarray.Array1.unsafe_get data_buf data_pos in
    let s_val = Bigarray.Array1.unsafe_get seed_buf seed_pos in
    let res0, _ =
      Threefry_impl.threefry2x32_20_rounds d_val c1_fixed s_val k1_fixed
    in
    Bigarray.Array1.unsafe_set out_buf out_pos res0
  in

  if is_c_contiguous data_t && is_c_contiguous seed_t then (
    let data_base = data_offset + start_idx in
    let seed_base = seed_offset + start_idx in
    let out_base = out_offset + start_idx in
    let hash_rel i = hash_at (data_base + i) (seed_base + i) (out_base + i) in
    let n = end_idx - start_idx in
    let i = ref 0 in
    (* Main loop, unrolled four elements at a time. *)
    while !i + 3 < n do
      hash_rel !i;
      hash_rel (!i + 1);
      hash_rel (!i + 2);
      hash_rel (!i + 3);
      i := !i + 4
    done;
    (* Remaining 0 to 3 elements. *)
    while !i < n do
      hash_rel !i;
      incr i
    done)
  else
    let out_shape = shape out_t in
    let data_strides = strides data_t in
    let seed_strides = strides seed_t in

    (* Pre-allocate work array *)
    let md_index = Array.make (Array.length out_shape) 0 in

    for k = start_idx to end_idx - 1 do
      (* Turn the linear output index into a multi-dimensional index, then
         locate the matching input elements through their own strides. *)
      Shape.unravel_index_into k out_shape md_index;
      let data_lin = Shape.ravel_index md_index data_strides in
      let seed_lin = Shape.ravel_index md_index seed_strides in
      (* Honour the output offset here too, as the contiguous path does. *)
      hash_at (data_offset + data_lin) (seed_offset + seed_lin) (out_offset + k)
    done

(* [threefry context data_t seed_t out_t] fills [out_t] by running
   [kernel_threefry_int32] over index ranges handed out by
   [Parallel.parallel_for] on [context.pool]. It does nothing when [out_t]
   has no elements.
   NOTE(review): the range passed is [0 .. size - 1]; how [parallel_for]
   turns it into the [start_idx]/[end_idx] pairs given to the kernel is
   defined elsewhere — confirm the bounds convention there. *)
let threefry (context : context) (data_t : (int32, int32_elt) t)
    (seed_t : (int32, int32_elt) t) (out_t : (int32, int32_elt) t) : unit =
  let run_chunk start_idx end_idx =
    kernel_threefry_int32 data_t seed_t out_t start_idx end_idx
  in
  match size out_t with
  | 0 -> ()
  | total -> Parallel.parallel_for context.pool 0 (total - 1) run_chunk
OCaml

Innovation. Community. Security.