package neural_nets_lib

  1. Overview
  2. Docs

Source file ppx_op.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
open Base
open Ppxlib
open Ppx_arrayjit.Ppx_helper
open Ppx_shared

let ndarray_op ?ident_label ?axis_labels expr =
  let loc = expr.pexp_loc in
  let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
  let edims dims = Ast_builder.Default.elist ~loc dims in
  let op =
    match axis_labels with
    | None -> [%expr TDSL.ndarray]
    | Some axis_labels -> [%expr TDSL.ndarray ~axis_labels:[%e axis_labels]]
  in
  [%expr
    [%e op] ~label:[%e opt_pat2string_list ~loc ident_label] ~batch_dims:[%e edims batch_dims]
      ~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims] [%e values]]

let make_vb ?value ~loc ~str_loc ~ident string =
  let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
  let value = match value with Some c -> [%expr Some [%e c]] | None -> [%expr None] in
  let v = [%expr TDSL.param ?values:[%e value] [%e string]] in
  let vb = Ast_helper.Vb.mk ~loc pat v in
  (pat, vb)

let make_vb_dims ~loc ~str_loc ~ident ~dims ~dims_loc string =
  let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
  let dims =
    let loc = dims_loc in
    List.fold_right dims ~init:[%expr []] ~f:(fun d ds -> [%expr [%e d] :: [%e ds]])
  in
  let v = [%expr TDSL.param ~output_dims:[%e dims] [%e string]] in
  let vb = Ast_helper.Vb.mk ~loc pat v in
  (pat, vb)

let make_vb_nd ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
  let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
  let values, batch_dims, output_dims, input_dims = ndarray_constant init_nd in
  let v =
    if not @@ List.is_empty batch_dims then
      Ast_builder.Default.pexp_extension ~loc
      @@ Location.error_extensionf ~loc
           "ppx_ocannl param cannot have batch dims: define a constant or remove the array syntax."
    else
      let edims dims = Ast_builder.Default.elist ~loc dims in
      let op =
        match axis_labels with
        | None -> [%expr TDSL.param]
        | Some axis_labels -> [%expr TDSL.param ~axis_labels:[%e axis_labels]]
      in
      [%expr
        [%e op] ~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims] ~values:[%e values]
          [%e string]]
  in
  let vb = Ast_helper.Vb.mk ~loc pat v in
  (pat, vb)

let rec translate ?ident_label expr =
  let loc = expr.pexp_loc in
  match expr with
  | { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
      (no_vbs, [%expr TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] [%e expr]])
  | { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
      (no_vbs, [%expr TDSL.number (Float.of_int [%e expr])])
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
        [%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
      let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in
      ( no_vbs,
        [%expr TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] [%e f]] )
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
        [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
      let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in
      ( no_vbs,
        [%expr
          TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis]
            (Float.of_int [%e i])] )
  | [%expr
      [%e? expr1]
      *+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec] [%e? expr2]]
    when String.contains spec_str '>' ->
      let vbs1, e1 = translate expr1 in
      let vbs2, e2 = translate expr2 in
      ( reduce_vbss [ vbs1; vbs2 ],
        [%expr TDSL.einsum ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1] [%e e2]] )
  | [%expr [%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
    when String.contains spec_str '>' ->
      let vbs1, e1 = translate expr1 in
      (vbs1, [%expr TDSL.einsum1 ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1]])
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
        [%e?
          ( { pexp_desc = Pexp_constant (Pconst_integer _); pexp_loc = dims_loc; _ }
          | { pexp_desc = Pexp_ident _; pexp_loc = dims_loc; _ } ) as d]] ->
      let pat, vb = make_vb_dims ~loc ~str_loc ~ident ~dims:[ d ] ~dims_loc s in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
        [%e?
          ({ pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ })
          as init_nd]] ->
      let pat, vb = make_vb_nd ~loc ~str_loc ~ident ~init_nd s in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
        [%e? { pexp_desc = Pexp_tuple dims; pexp_loc = dims_loc; _ }]] ->
      let pat, vb = make_vb_dims ~loc ~str_loc ~ident ~dims ~dims_loc s in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } ->
      let pat, vb = make_vb ~loc ~str_loc ~ident expr in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | { pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
      (no_vbs, ndarray_op ?ident_label expr)
  | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
      (* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)
      let vbs, e1 = translate expr1 in
      ( vbs,
        [%expr TDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] (Float.of_int [%e i])]
      )
  | [%expr [%e? expr1] **. [%e? expr2]] ->
      let vbs, e1 = translate expr1 in
      (vbs, [%expr TDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]])
  | [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
      let vbs1, e1 = translate ?ident_label expr1 in
      let vbs2, e2 = translate expr2 in
      let vbs3, e3 = translate expr3 in
      (reduce_vbss [ vbs1; vbs2; vbs3 ], [%expr [%e e1] [%e e2] [%e e3]])
  | [%expr [%e? expr1] [%e? expr2]] ->
      let vbs1, e1 = translate ?ident_label expr1 in
      let vbs2, e2 = translate expr2 in
      (Map.merge_skewed vbs1 vbs2 ~combine:(fun ~key:_ _v1 v2 -> v2), [%expr [%e e1] [%e e2]])
  | [%expr fun ~config [%p? pat1] [%p? pat2] -> [%e? body]] ->
      (* TODO(#38): generalize config to any number of labeled arguments with any labels. *)
      let vbs, body = translate ?ident_label body in
      (no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs [%expr fun [%p pat1] [%p pat2] -> [%e body]]]])
  | [%expr fun ~config [%p? pat] -> [%e? body]] ->
      (* TODO(#38): generalize config to any number of labeled arguments with any labels. *)
      let vbs, body = translate ?ident_label body in
      (no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs [%expr fun [%p pat] -> [%e body]]]])
  | [%expr fun [%p? pat1] [%p? pat2] -> [%e? body]] ->
      let vbs, body = translate ?ident_label body in
      (vbs, [%expr fun [%p pat1] [%p pat2] -> [%e body]])
  | [%expr fun [%p? pat] -> [%e? body]] ->
      let vbs, body = translate ?ident_label body in
      (vbs, [%expr fun [%p pat] -> [%e body]])
  | [%expr
      while [%e? test_expr] do
        [%e? body_expr]
      done] ->
      let vbs, body = translate ?ident_label body_expr in
      ( vbs,
        [%expr
          while [%e test_expr] do
            [%e body]
          done] )
  | [%expr
      for [%p? pat] = [%e? init] to [%e? final] do
        [%e? body_expr]
      done] ->
      let vbs, body = translate ?ident_label body_expr in
      ( vbs,
        [%expr
          for [%p pat] = [%e init] to [%e final] do
            [%e body]
          done] )
  | [%expr
      for [%p? pat] = [%e? init] downto [%e? final] do
        [%e? body_expr]
      done] ->
      let vbs, body = translate ?ident_label body_expr in
      ( vbs,
        [%expr
          for [%p pat] = [%e init] downto [%e final] do
            [%e body]
          done] )
  | [%expr
      [%e? expr1];
      [%e? expr2]] ->
      let vbs1, e1 = translate expr1 in
      let vbs2, e2 = translate ?ident_label expr2 in
      ( reduce_vbss [ vbs1; vbs2 ],
        [%expr
          [%e e1];
          [%e e2]] )
  | [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
      let vbs2, e2 = translate ?ident_label expr2 in
      let vbs3, e3 = translate ?ident_label expr3 in
      (reduce_vbss [ vbs2; vbs3 ], [%expr if [%e expr1] then [%e e2] else [%e e3]])
  | [%expr if [%e? expr1] then [%e? expr2]] ->
      let vbs2, e2 = translate ?ident_label expr2 in
      (vbs2, [%expr if [%e expr1] then [%e e2]])
  | { pexp_desc = Pexp_match (expr1, cases); _ } ->
      let vbss, cases =
        List.unzip
        @@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
               let vbs, pc_rhs = translate ?ident_label pc_rhs in
               (vbs, { c with pc_rhs }))
      in
      (reduce_vbss vbss, { expr with pexp_desc = Pexp_match (expr1, cases) })
  | { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
      let vbss1, bindings =
        List.unzip
        @@ List.map bindings ~f:(fun binding ->
               let vbs, pvb_expr = translate ~ident_label:binding.pvb_pat binding.pvb_expr in
               (vbs, { binding with pvb_expr }))
      in
      let vbs2, body = translate ?ident_label body in
      let all_bindings = (Map.data @@ reduce_vbss vbss1) @ bindings @ Map.data vbs2 in
      (no_vbs, { expr with pexp_desc = Pexp_let (recflag, all_bindings, body) })
  | { pexp_desc = Pexp_open (decl, body); _ } ->
      let vbs, body = translate ?ident_label body in
      (vbs, { expr with pexp_desc = Pexp_open (decl, body) })
  | { pexp_desc = Pexp_letmodule (name, module_expr, body); _ } ->
      let vbs, body = translate ?ident_label body in
      (vbs, { expr with pexp_desc = Pexp_letmodule (name, module_expr, body) })
  | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
      (no_vbs, [%expr [%e expr] ~label:[%e opt_pat2string_list ~loc ident_label]])
  | expr -> (no_vbs, expr)

let expr_expander ~loc ~path:_ payload =
  match payload with
  | { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
      (* We are at the %op annotation level: do not tranlsate the body. *)
      let vbss, bindings =
        List.unzip
        @@ List.map bindings ~f:(fun vb ->
               let vbs, v = translate ~ident_label:vb.pvb_pat vb.pvb_expr in
               ( vbs,
                 {
                   vb with
                   pvb_expr =
                     [%expr
                       let open! TDSL.O in
                       [%e v]];
                 } ))
      in
      let expr = { payload with pexp_desc = Pexp_let (recflag, bindings, body) } in
      let_opt ~loc (reduce_vbss vbss) expr
  | expr ->
      let vbs, expr = translate expr in
      let_opt ~loc vbs expr

let flatten_str ~loc ~path:_ items =
  match items with
  | [ x ] -> x
  | _ ->
      Ast_helper.Str.include_
        { pincl_mod = Ast_helper.Mod.structure items; pincl_loc = loc; pincl_attributes = [] }

let translate_str ({ pstr_desc; pstr_loc = loc; _ } as str) =
  match pstr_desc with
  | Pstr_eval (expr, attrs) ->
      let vbs, expr = translate expr in
      let expr = let_opt ~loc vbs expr in
      { str with pstr_desc = Pstr_eval (expr, attrs) }
  | Pstr_value (recf, bindings) ->
      let f vb =
        let loc = vb.pvb_loc in
        let ident_label = vb.pvb_pat in
        let vbs, v = translate ~ident_label vb.pvb_expr in
        let is_unused = match ident_label with [%pat? _] -> true | _ -> false in
        let v = let_opt ~loc vbs v in
        {
          vb with
          pvb_expr =
            [%expr
              let open! TDSL.O in
              [%e if is_unused then [%expr Tensor.with_unchanged_roots ~f:(fun () -> [%e v])] else v]];
        }
      in
      { str with pstr_desc = Pstr_value (recf, List.map bindings ~f) }
  | _ -> str

let str_expander ~loc ~path (payload : structure_item list) =
  flatten_str ~loc ~path @@ List.map payload ~f:translate_str
OCaml

Innovation. Community. Security.