package arrayjit

  1. Overview
  2. Docs

Source file backend_utils.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
open Base
module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime

let _get_local_debug_runtime = Utils._get_local_debug_runtime

[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

module Types = struct
  type 'context routine = {
    context : 'context;
    schedule : Tnode.task;
    bindings : Indexing.lowered_bindings;
    name : string;
  }
  [@@deriving sexp_of]

  type config = Physical_devices_only | For_parallel_copying | Most_parallel_devices
  [@@deriving equal, sexp, variants]

  type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]

  type param_source =
    | Log_file_name
    | Merge_buffer
    | Param_ptr of Tnode.t
    | Static_idx of Indexing.static_symbol
  [@@deriving sexp_of]
end

module Tn = Tnode

module C_syntax (B : sig
  val for_lowereds : Low_level.optimized array

  type ctx_array

  val opt_ctx_arrays : ctx_array Map.M(Tnode).t option
  val hardcoded_context_ptr : (ctx_array -> string) option
  val is_in_context : Low_level.traced_array -> bool
  val host_ptrs_for_readonly : bool
  val logs_to_stdout : bool
  val main_kernel_prefix : string
  val kernel_prep_line : string
  val include_lines : string list
  val typ_of_prec : Ops.prec -> string
  val binop_syntax : Ops.prec -> Ops.binop -> string * string * string
  val unop_syntax : Ops.prec -> Ops.unop -> string * string
  val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
end) =
struct
  open Types

  let get_ident =
    Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.for_lowereds ~f:(fun l -> l.llc)

  let pp_zero_out ppf tn =
    Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn

  (* let pp_semi ppf () = Stdlib.Format.fprintf ppf ";@ " *)
  let pp_comma ppf () = Stdlib.Format.fprintf ppf ",@ "

  (* let pp_symbol ppf sym = Stdlib.Format.fprintf ppf "%s" @@ Indexing.symbol_ident sym *)
  let pp_index ppf sym = Stdlib.Format.fprintf ppf "%s" @@ Indexing.symbol_ident sym

  let pp_index_axis ppf = function
    | Indexing.Iterator it -> pp_index ppf it
    | Fixed_idx i when i < 0 -> Stdlib.Format.fprintf ppf "(%d)" i
    | Fixed_idx i -> Stdlib.Format.fprintf ppf "%d" i

  let pp_array_offset ppf (idcs, dims) =
    let open Stdlib.Format in
    assert (not @@ Array.is_empty idcs);
    for _ = 0 to Array.length idcs - 3 do
      fprintf ppf "@[<1>("
    done;
    for i = 0 to Array.length idcs - 1 do
      let dim = dims.(i) in
      if i = 0 then fprintf ppf "%a" pp_index_axis idcs.(i)
      else if i = Array.length idcs - 1 then fprintf ppf " * %d + %a" dim pp_index_axis idcs.(i)
      else fprintf ppf " * %d +@ %a@;<0 -1>)@]" dim pp_index_axis idcs.(i)
    done

  let array_offset_to_string (idcs, dims) =
    let b = Buffer.create 32 in
    let ppf = Stdlib.Format.formatter_of_buffer b in
    pp_array_offset ppf (idcs, dims);
    Stdlib.Format.pp_print_flush ppf ();
    Buffer.contents b

  (* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
     -> idx + (offset * dim)) *)
  let%diagn_sexp compile_globals ppf =
    let open Stdlib.Format in
    let is_global = Hash_set.create (module Tn) in
    fprintf ppf {|@[<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)
      B.include_lines;
    Array.iter B.for_lowereds ~f:(fun l ->
        Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
            if not @@ Hash_set.mem is_global node.tn then
              let in_ctx : bool = B.is_in_context node in
              let ctx_ptr = B.hardcoded_context_ptr in
              let mem : (Tn.memory_mode * int) option = node.tn.memory_mode in
              match
                (in_ctx, ctx_ptr, B.opt_ctx_arrays, B.host_ptrs_for_readonly, mem, node.read_only)
              with
              | true, Some get_ptr, Some ctx_arrays, _, _, _ ->
                  let ident = get_ident node.tn in
                  let ctx_array =
                    Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays node.tn
                  in
                  fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array;
                  Hash_set.add is_global node.tn
              | false, _, _, true, Some (Hosted _, _), true ->
                  (* In-context nodes to read directly from host would be error prone. *)
                  let nd = Option.value_exn ~here:[%here] @@ Lazy.force node.tn.array in
                  fprintf ppf "#define %s (%s)@," (get_ident node.tn) (Ndarray.c_ptr_to_string nd);
                  Hash_set.add is_global node.tn
              | _ -> ()));
    fprintf ppf "@,@]";
    is_global

  let compile_main ~traced_store ppf llc : unit =
    let open Stdlib.Format in
    let visited = Hash_set.create (module Tn) in
    let rec pp_ll ppf c : unit =
      match c with
      | Low_level.Noop -> ()
      | Seq (c1, c2) ->
          (* Note: no separator. Filter out some entries known to not generate code to avoid
             whitespace. *)
          fprintf ppf "@[<v 0>%a@]" (pp_print_list pp_ll)
            (List.filter [ c1; c2 ] ~f:(function Noop -> false | _ -> true))
      | For_loop { index = i; from_; to_; body; trace_it = _ } ->
          fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ " pp_index i from_ pp_index i
            to_ pp_index i;
          if Utils.debug_log_from_routines () then
            if B.logs_to_stdout then
              fprintf ppf {|printf(@[<h>"%s%%d: index %a = %%d\n",@] log_id, %a);@ |}
                !Utils.captured_log_prefix pp_index i pp_index i
            else
              fprintf ppf {|fprintf(log_file,@ @[<h>"index %a = %%d\n",@] %a);@ |} pp_index i
                pp_index i;
          fprintf ppf "%a@;<1 -2>}@]@," pp_ll body
      | Zero_out tn ->
          let traced = Low_level.(get_node traced_store tn) in
          (* The initialization will be emitted at the end of compile_proc. *)
          if Hash_set.mem visited tn then pp_zero_out ppf tn else assert traced.zero_initialized
      | Set { tn; idcs; llv; debug } ->
          Hash_set.add visited tn;
          let ident = get_ident tn in
          let dims = Lazy.force tn.dims in
          let loop_f = pp_float @@ Lazy.force tn.prec in
          let loop_debug_f = debug_float @@ Lazy.force tn.prec in
          let num_closing_braces = pp_top_locals ppf llv in
          let num_typ = B.typ_of_prec @@ Lazy.force tn.prec in
          if Utils.debug_log_from_routines () then (
            fprintf ppf "@[<2>{@ @[<2>%s new_set_v =@ %a;@]@ " num_typ loop_f llv;
            let v_code, v_idcs = loop_debug_f llv in
            let pp_args =
              pp_print_list @@ fun ppf -> function
              | `Accessor idx ->
                  pp_comma ppf ();
                  pp_array_offset ppf idx
              | `Value v ->
                  pp_comma ppf ();
                  pp_print_string ppf v
            in
            let offset = (idcs, dims) in
            if B.logs_to_stdout then (
              fprintf ppf {|@[<7>printf(@[<h>"%s%%d: # %s\n", log_id@]);@]@ |}
                !Utils.captured_log_prefix
                (String.substr_replace_all debug ~pattern:"\n" ~with_:"$");
              fprintf ppf
                {|@[<7>printf(@[<h>"%s%%d: %s[%%u]{=%%g} = %%g = %s\n",@]@ log_id,@ %a,@ %s[%a],@ new_set_v%a);@]@ |}
                !Utils.captured_log_prefix ident v_code pp_array_offset offset ident pp_array_offset
                offset pp_args v_idcs)
            else (
              fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"# %s\n"@]);@]@ |}
                (String.substr_replace_all debug ~pattern:"\n" ~with_:"$");
              fprintf ppf
                {|@[<7>fprintf(log_file,@ @[<h>"%s[%%u]{=%%g} = %%g = %s\n",@]@ %a,@ %s[%a],@ new_set_v%a);@]@ |}
                ident v_code pp_array_offset offset ident pp_array_offset offset pp_args v_idcs);
            if not B.logs_to_stdout then fprintf ppf "fflush(log_file);@ ";
            fprintf ppf "@[<2>%s[@,%a] =@ new_set_v;@]@;<1 -2>}@]@ " ident pp_array_offset
              (idcs, dims))
          else
            (* No idea why adding any cut hint at the end of the assign line breaks formatting! *)
            fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " ident pp_array_offset (idcs, dims) loop_f llv;
          for _ = 1 to num_closing_braces do
            fprintf ppf "@]@ }@,"
          done
      | Comment message ->
          if Utils.debug_log_from_routines () then
            if B.logs_to_stdout then
              fprintf ppf {|printf(@[<h>"%s%%d: COMMENT: %s\n",@] log_id);@ |}
                !Utils.captured_log_prefix
                (String.substr_replace_all ~pattern:"%" ~with_:"%%" message)
            else
              fprintf ppf {|fprintf(log_file,@ @[<h>"COMMENT: %s\n"@]);@ |}
                (String.substr_replace_all ~pattern:"%" ~with_:"%%" message)
          else fprintf ppf "/* %s */@ " message
      | Staged_compilation callback -> callback ()
      | Set_local (Low_level.{ scope_id; tn = { prec; _ } }, value) ->
          let num_closing_braces = pp_top_locals ppf value in
          fprintf ppf "@[<2>v%d =@ %a;@]" scope_id (pp_float @@ Lazy.force prec) value;
          for _ = 1 to num_closing_braces do
            fprintf ppf "@]@ }@,"
          done
    and pp_top_locals ppf (vcomp : Low_level.float_t) : int =
      match vcomp with
      | Local_scope { id = { scope_id = i; tn = { prec; _ } }; body; orig_indices = _ } ->
          let num_typ = B.typ_of_prec @@ Lazy.force prec in
          (* Arrays are initialized to 0 by default. However, there is typically an explicit
             initialization for virtual nodes. *)
          fprintf ppf "@[<2>{@ %s v%d = 0;@ " num_typ i;
          pp_ll ppf body;
          pp_print_space ppf ();
          1
      | Get_local _ | Get_global _ | Get _ | Constant _ | Embed_index _ -> 0
      | Binop (Arg1, v1, _v2) -> pp_top_locals ppf v1
      | Binop (Arg2, _v1, v2) -> pp_top_locals ppf v2
      | Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2
      | Unop (_, v) -> pp_top_locals ppf v
    and pp_float (prec : Ops.prec) ppf value =
      let loop = pp_float prec in
      match value with
      | Local_scope { id; _ } ->
          (* Embedding of Local_scope is done by pp_top_locals. *)
          loop ppf @@ Get_local id
      | Get_local id ->
          let prefix, postfix = B.convert_precision ~from:(Lazy.force id.tn.prec) ~to_:prec in
          fprintf ppf "%sv%d%s" prefix id.scope_id postfix
      | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
          let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
          let prefix, postfix = B.convert_precision ~from:(Lazy.force tn.prec) ~to_:prec in
          fprintf ppf "@[<2>%smerge_buffer[%a@;<0 -2>]%s@]" prefix pp_array_offset
            (idcs, Lazy.force tn.dims)
            postfix
      | Get_global _ -> failwith "C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
      | Get (tn, idcs) ->
          Hash_set.add visited tn;
          let ident = get_ident tn in
          let prefix, postfix = B.convert_precision ~from:(Lazy.force tn.prec) ~to_:prec in
          fprintf ppf "@[<2>%s%s[%a@;<0 -2>]%s@]" prefix ident pp_array_offset
            (idcs, Lazy.force tn.dims)
            postfix
      | Constant c ->
          let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
          let prefix, postfix =
            if String.is_empty prefix && Float.(c < 0.0) then ("(", ")" ^ postfix)
            else (prefix, postfix)
          in
          fprintf ppf "%s%.16g%s" prefix c postfix
      | Embed_index idx ->
          let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
          fprintf ppf "%s%a%s" prefix pp_index_axis idx postfix
      | Binop (Arg1, v1, _v2) -> loop ppf v1
      | Binop (Arg2, _v1, v2) -> loop ppf v2
      | Binop (op, v1, v2) ->
          let prefix, infix, postfix = B.binop_syntax prec op in
          fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix loop v1 infix loop v2 postfix
      | Unop (op, v) ->
          let prefix, postfix = B.unop_syntax prec op in
          fprintf ppf "@[<1>%s%a@]%s" prefix loop v postfix
    and debug_float (prec : Ops.prec) (value : Low_level.float_t) : string * 'a list =
      let loop = debug_float prec in
      match value with
      | Local_scope { id; _ } ->
          (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug
             logs. *)
          loop @@ Get_local id
      | Get_local id ->
          let prefix, postfix = B.convert_precision ~from:(Lazy.force id.tn.prec) ~to_:prec in
          let v = String.concat [ prefix; "v"; Int.to_string id.scope_id; postfix ] in
          (v ^ "{=%g}", [ `Value v ])
      | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
          let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
          let prefix, postfix = B.convert_precision ~from:(Lazy.force tn.prec) ~to_:prec in
          let dims = Lazy.force tn.dims in
          let v =
            sprintf "@[<2>%smerge_buffer[%s@;<0 -2>]%s@]" prefix
              (array_offset_to_string (idcs, dims))
              postfix
          in
          ( String.concat [ prefix; "merge_buffer[%u]"; postfix; "{=%g}" ],
            [ `Accessor (idcs, dims); `Value v ] )
      | Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET"
      | Get (tn, idcs) ->
          let dims = Lazy.force tn.dims in
          let ident = get_ident tn in
          let prefix, postfix = B.convert_precision ~from:(Lazy.force tn.prec) ~to_:prec in
          let v =
            sprintf "@[<2>%s%s[%s@;<0 -2>]%s@]" prefix ident
              (array_offset_to_string (idcs, dims))
              postfix
          in
          ( String.concat [ prefix; ident; "[%u]"; postfix; "{=%g}" ],
            [ `Accessor (idcs, dims); `Value v ] )
      | Constant c ->
          let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
          (prefix ^ Float.to_string c ^ postfix, [])
      | Embed_index (Fixed_idx i) -> (Int.to_string i, [])
      | Embed_index (Iterator s) -> (Indexing.symbol_ident s, [])
      | Binop (Arg1, v1, _v2) -> loop v1
      | Binop (Arg2, _v1, v2) -> loop v2
      | Binop (op, v1, v2) ->
          let prefix, infix, postfix = B.binop_syntax prec op in
          let v1, idcs1 = loop v1 in
          let v2, idcs2 = loop v2 in
          (String.concat [ prefix; v1; infix; " "; v2; postfix ], idcs1 @ idcs2)
      | Unop (op, v) ->
          let prefix, postfix = B.unop_syntax prec op in
          let v, idcs = loop v in
          (String.concat [ prefix; v; postfix ], idcs)
    in
    pp_ll ppf llc

  let%diagn_sexp compile_proc ~name ppf idx_params ~is_global
      Low_level.{ traced_store; llc; merge_node } =
    let open Stdlib.Format in
    let params : (string * param_source) list =
      (* Preserve the order in the hashtable, so it's the same as e.g. in compile_globals. *)
      List.rev
      @@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:node params ->
             (* A rough approximation to the type Gccjit_backend.mem_properties. *)
             let backend_info =
               Sexp.Atom
                 (if B.is_in_context node then "From_context"
                  else if Hash_set.mem is_global tn then "Constant_from_host"
                  else if Tn.is_virtual_force tn 3331 then "Virtual"
                  else "Local_only")
             in
             if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
               tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
             if B.is_in_context node && not (Hash_set.mem is_global tn) then
               (B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
             else params)
    in
    let idx_params =
      List.map idx_params ~f:(fun s ->
          ("int " ^ Indexing.symbol_ident s.Indexing.static_symbol, Static_idx s))
    in
    let log_file =
      if Utils.debug_log_from_routines () then
        [
          ((if B.logs_to_stdout then "int log_id" else "const char* log_file_name"), Log_file_name);
        ]
      else []
    in
    let merge_param =
      Option.(
        to_list
        @@ map merge_node ~f:(fun tn ->
               ("const " ^ B.typ_of_prec (Lazy.force tn.prec) ^ " *merge_buffer", Merge_buffer)))
    in
    let params = log_file @ merge_param @ idx_params @ params in
    let params =
      List.sort params ~compare:(fun (p1_name, _) (p2_name, _) -> compare_string p1_name p2_name)
    in
    fprintf ppf "@[<v 2>@[<hv 4>%s%svoid %s(@,@[<hov 0>%a@]@;<0 -4>)@] {@ " B.main_kernel_prefix
      (if String.is_empty B.main_kernel_prefix then "" else " ")
      name
      (pp_print_list ~pp_sep:pp_comma pp_print_string)
    @@ List.map ~f:fst params;
    if not (String.is_empty B.kernel_prep_line) then fprintf ppf "%s@ " B.kernel_prep_line;
    (* FIXME: we should also close the file. *)
    if (not (List.is_empty log_file)) && not B.logs_to_stdout then
      fprintf ppf {|FILE* log_file = fopen(log_file_name, "w");@ |};
    if Utils.debug_log_from_routines () then (
      fprintf ppf "/* Debug initial parameter state. */@ ";
      List.iter
        ~f:(function
          | p_name, Merge_buffer ->
              if B.logs_to_stdout then
                fprintf ppf
                  {|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)merge_buffer);@]@ |}
                  !Utils.captured_log_prefix p_name
              else
                fprintf ppf
                  {|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)merge_buffer);@]@ |} p_name
          | _, Log_file_name -> ()
          | p_name, Param_ptr tn ->
              if B.logs_to_stdout then
                fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)%s);@]@ |}
                  !Utils.captured_log_prefix p_name
                @@ get_ident tn
              else
                fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)%s);@]@ |} p_name
                @@ get_ident tn
          | p_name, Static_idx s ->
              if B.logs_to_stdout then
                fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%d\n",@] log_id, %s);@]@ |}
                  !Utils.captured_log_prefix p_name
                @@ Indexing.symbol_ident s.Indexing.static_symbol
              else
                fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%d\n",@] %s);@]@ |} p_name
                @@ Indexing.symbol_ident s.Indexing.static_symbol)
        params);
    fprintf ppf "/* Local declarations and initialization. */@ ";
    Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
        if not (Tn.is_virtual_force tn 333 || B.is_in_context node || Hash_set.mem is_global tn)
        then
          fprintf ppf "%s %s[%d]%s;@ "
            (B.typ_of_prec @@ Lazy.force tn.prec)
            (get_ident tn) (Tn.num_elems tn)
            (if node.zero_initialized then " = {0}" else "")
        else if (not (Tn.is_virtual_force tn 333)) && node.zero_initialized then pp_zero_out ppf tn);
    fprintf ppf "@,/* Main logic. */@ ";
    compile_main ~traced_store ppf llc;
    fprintf ppf "@;<0 -2>}@]@.";
    params
end

let check_merge_buffer ~merge_buffer ~code_node =
  let device_node = Option.map !merge_buffer ~f:snd in
  let name = function Some tn -> Tn.debug_name tn | None -> "none" in
  match (device_node, code_node) with
  | _, None -> ()
  | Some actual, Some expected when Tn.equal actual expected -> ()
  | _ ->
      raise
      @@ Utils.User_error
           ("Merge buffer mismatch, on device: " ^ name device_node ^ ", expected by code: "
          ^ name code_node)
OCaml

Innovation. Community. Security.