Source file backend_utils.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
open Base
module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime
let _get_local_debug_runtime = Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
(* Type definitions shared by the backend helpers of this file. *)
module Types = struct
(* A named piece of executable work: a [schedule] task together with the [context] it is
   associated with and the lowered index [bindings].
   NOTE(review): the precise meaning of each field is decided by the backends that build
   routines, which are not visible in this file -- confirm there. *)
type 'context routine = {
context : 'context;
schedule : Tnode.task;
bindings : Indexing.lowered_bindings;
name : string;
}
[@@deriving sexp_of]
(* Backend configuration choice.
   NOTE(review): the semantics of the individual variants are not visible in this file. *)
type config = Physical_devices_only | For_parallel_copying | Most_parallel_devices
[@@deriving equal, sexp, variants]
(* Whether and how a merge buffer is used.
   NOTE(review): the difference between [Streaming] and [Copy] is not visible in this file. *)
type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]
(* Origin of the value passed for one parameter of a generated function. [C_syntax.compile_proc]
   returns each parameter declaration paired with one of these. *)
type param_source =
(* The logging parameter: declared as [int log_id] when logging to stdout, and as the C string
   [log_file_name] otherwise. *)
| Log_file_name
(* The [merge_buffer] pointer parameter. *)
| Merge_buffer
(* A pointer parameter holding the array of the given tensor node. *)
| Param_ptr of Tnode.t
(* An [int] parameter carrying the value of a static index symbol. *)
| Static_idx of Indexing.static_symbol
[@@deriving sexp_of]
end
module Tn = Tnode
(* Generator of C-like source text, parameterized by the backend-specific settings in [B]. *)
module C_syntax (B : sig
(* The lowered code units compiled together: identifiers are derived from the [llc] of all of
   them, and [compile_globals] scans the [traced_store] of each. *)
val for_lowereds : Low_level.optimized array
(* Backend representation of an array held in a context; opaque to this functor. *)
type ctx_array
(* When present together with [hardcoded_context_ptr], [compile_globals] looks up every
   in-context node in this map; a missing entry raises. *)
val opt_ctx_arrays : ctx_array Map.M(Tnode).t option
(* When present, renders a context array as the text that the node identifier is [#define]d
   to. *)
val hardcoded_context_ptr : (ctx_array -> string) option
(* Whether the traced array lives in the context. In [compile_proc], in-context arrays that
   were not defined globally become pointer parameters of the generated function. *)
val is_in_context : Low_level.traced_array -> bool
(* When true, [compile_globals] [#define]s read-only, hosted nodes that are not in context to
   the text produced by [Ndarray.c_ptr_to_string] for their host array. *)
val host_ptrs_for_readonly : bool
(* Selects the logging style of the generated code: [printf] lines carrying a [log_id], or
   [fprintf] to a [log_file] opened inside the generated function. *)
val logs_to_stdout : bool
(* Text emitted before [void] in the generated function header (followed by a space when
   non-empty). *)
val main_kernel_prefix : string
(* When non-empty, emitted as the first item of the generated function body. *)
val kernel_prep_line : string
(* Lines printed at the very top of the globals section. *)
val include_lines : string list
(* The C type name used for values of the given precision. *)
val typ_of_prec : Ops.prec -> string
(* [(prefix, infix, postfix)] texts printed around the two operands of a binary operation. *)
val binop_syntax : Ops.prec -> Ops.binop -> string * string * string
(* [(prefix, postfix)] texts printed around the operand of a unary operation. *)
val unop_syntax : Ops.prec -> Ops.unop -> string * string
(* [(prefix, postfix)] texts printed around a value of precision [from] where a value of
   precision [to_] is expected. *)
val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
end) =
struct
open Types
(* [get_ident tn] is the identifier used for tensor node [tn] in the generated code. The
   naming function is built once, from the code of every unit in [B.for_lowereds]. *)
let get_ident =
  let all_code = Array.map B.for_lowereds ~f:(fun lowered -> lowered.Low_level.llc) in
  Low_level.get_ident_within_code ~no_dots:true all_code
(* [pp_zero_out ppf tn] prints a [memset] call that clears all [Tn.size_in_bytes tn] bytes of
   the array of [tn]. *)
let pp_zero_out ppf tn =
  let num_bytes = Tn.size_in_bytes tn in
  let ident = get_ident tn in
  Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " ident num_bytes
(* Separator for argument lists: a comma followed by a breakable space. *)
let pp_comma ppf () =
  Stdlib.Format.pp_print_string ppf ",";
  Stdlib.Format.pp_print_space ppf ()
(* Prints the identifier of an index symbol, as given by [Indexing.symbol_ident]. *)
let pp_index ppf sym = Stdlib.Format.pp_print_string ppf (Indexing.symbol_ident sym)
(* Prints one axis index: an iterator is printed as its symbol identifier, a fixed position as
   a decimal literal. Negative literals are wrapped in parentheses. *)
let pp_index_axis ppf axis =
  match axis with
  | Indexing.Iterator it -> pp_index ppf it
  | Indexing.Fixed_idx pos ->
      if pos < 0 then Stdlib.Format.fprintf ppf "(%d)" pos
      else Stdlib.Format.pp_print_int ppf pos
(* [pp_array_offset ppf (idcs, dims)] prints the flattened offset expression for the indices
   [idcs] into an array with dimensions [dims], in nested form: for three axes the output reads
   [(i0 * d1 + i1) * d2 + i2]. [idcs] must not be empty; [dims.(0)] is read but never
   printed. *)
let pp_array_offset ppf (idcs, dims) =
  let open Stdlib.Format in
  assert (not @@ Array.is_empty idcs);
  let last = Array.length idcs - 1 in
  (* Every axis strictly between the first and the last one closes one parenthesis, so that
     many are opened up front. *)
  for _ = 1 to last - 1 do
    fprintf ppf "@[<1>("
  done;
  Array.iteri idcs ~f:(fun pos idx ->
      let dim = dims.(pos) in
      if pos = 0 then pp_index_axis ppf idx
      else if pos = last then fprintf ppf " * %d + %a" dim pp_index_axis idx
      else fprintf ppf " * %d +@ %a@;<0 -1>)@]" dim pp_index_axis idx)
(* The text printed by [pp_array_offset] for [(idcs, dims)], returned as a string. *)
let array_offset_to_string (idcs, dims) =
  Stdlib.Format.asprintf "%a" pp_array_offset (idcs, dims)
(* [compile_globals ppf] prints the include lines followed by the global declarations section,
   and returns the set of tensor nodes whose identifier was [#define]d globally. Each node is
   handled at most once, even if it occurs in several of the [B.for_lowereds] units. *)
let%diagn_sexp compile_globals ppf =
let open Stdlib.Format in
let is_global = Hash_set.create (module Tn) in
fprintf ppf {|@[<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)
B.include_lines;
Array.iter B.for_lowereds ~f:(fun l ->
Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
if not @@ Hash_set.mem is_global node.tn then
let in_ctx : bool = B.is_in_context node in
let ctx_ptr = B.hardcoded_context_ptr in
let mem : (Tn.memory_mode * int) option = node.tn.memory_mode in
match
(in_ctx, ctx_ptr, B.opt_ctx_arrays, B.host_ptrs_for_readonly, mem, node.read_only)
with
(* In-context array, with hard-coded context pointers available: the identifier is defined
   as the text rendered for its context array. A node missing from the map raises, with
   the identifier as the message. *)
| true, Some get_ptr, Some ctx_arrays, _, _, _ ->
let ident = get_ident node.tn in
let ctx_array =
Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays node.tn
in
fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array;
Hash_set.add is_global node.tn
(* Read-only hosted array that is not in context, with host pointers enabled: the
   identifier is defined as the text of the host array pointer. Raises if the host array
   is absent. *)
| false, _, _, true, Some (Hosted _, _), true ->
let nd = Option.value_exn ~here:[%here] @@ Lazy.force node.tn.array in
fprintf ppf "#define %s (%s)@," (get_ident node.tn) (Ndarray.c_ptr_to_string nd);
Hash_set.add is_global node.tn
(* Everything else gets no global declaration. *)
| _ -> ()));
fprintf ppf "@,@]";
is_global
(* [compile_main ~traced_store ppf llc] prints the statements implementing the low-level code
   [llc] to [ppf]. [traced_store] is consulted for the [zero_initialized] flag of arrays
   that are zeroed out. *)
let compile_main ~traced_store ppf llc : unit =
let open Stdlib.Format in
(* Tensor nodes that have already been read or assigned by the code printed so far. *)
let visited = Hash_set.create (module Tn) in
(* [pp_ll ppf c] prints the statements for the low-level code [c]. When
   [Utils.debug_log_from_routines ()] holds, logging statements are interleaved with the
   computation: [printf] calls prefixed with the captured log prefix and [log_id] if
   [B.logs_to_stdout], [fprintf] calls on [log_file] otherwise. *)
let rec pp_ll ppf c : unit =
  match c with
  | Low_level.Noop -> ()
  | Seq (c1, c2) ->
      (* [Noop] operands are dropped before printing. *)
      fprintf ppf "@[<v 0>%a@]" (pp_print_list pp_ll)
        (List.filter [ c1; c2 ] ~f:(function Noop -> false | _ -> true))
  | For_loop { index = i; from_; to_; body; trace_it = _ } ->
      (* Both bounds are inclusive: the printed loop condition is [i <= to_]. *)
      fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ " pp_index i from_ pp_index i
        to_ pp_index i;
      if Utils.debug_log_from_routines () then
        if B.logs_to_stdout then
          fprintf ppf {|printf(@[<h>"%s%%d: index %a = %%d\n",@] log_id, %a);@ |}
            !Utils.captured_log_prefix pp_index i pp_index i
        else
          fprintf ppf {|fprintf(log_file,@ @[<h>"index %a = %%d\n",@] %a);@ |} pp_index i
            pp_index i;
      fprintf ppf "%a@;<1 -2>}@]@," pp_ll body
  | Zero_out tn ->
      (* An explicit memset is printed only if the array was already accessed by the code
         printed so far; otherwise the node must be flagged as zero initialized. *)
      let traced = Low_level.(get_node traced_store tn) in
      if Hash_set.mem visited tn then pp_zero_out ppf tn else assert traced.zero_initialized
  | Set { tn; idcs; llv; debug } ->
      Hash_set.add visited tn;
      let ident = get_ident tn in
      let dims = Lazy.force tn.dims in
      let loop_f = pp_float @@ Lazy.force tn.prec in
      let loop_debug_f = debug_float @@ Lazy.force tn.prec in
      (* Local scopes needed by the value are opened first; their braces are closed after the
         assignment. *)
      let num_closing_braces = pp_top_locals ppf llv in
      let num_typ = B.typ_of_prec @@ Lazy.force tn.prec in
      if Utils.debug_log_from_routines () then (
        (* The value is computed once into [new_set_v], logged, and then stored. *)
        fprintf ppf "@[<2>{@ @[<2>%s new_set_v =@ %a;@]@ " num_typ loop_f llv;
        let v_code, v_idcs = loop_debug_f llv in
        let pp_args =
          pp_print_list @@ fun ppf -> function
          | `Accessor idx ->
              pp_comma ppf ();
              pp_array_offset ppf idx
          | `Value v ->
              pp_comma ppf ();
              pp_print_string ppf v
        in
        let offset = (idcs, dims) in
        (* The debug text ends up inside a C format string: newlines are replaced to keep it on
           one line, and percent signs are doubled so that they are printed literally instead
           of being read as conversion specifications (as the [Comment] case does). *)
        let debug_line =
          debug
          |> String.substr_replace_all ~pattern:"\n" ~with_:"$"
          |> String.substr_replace_all ~pattern:"%" ~with_:"%%"
        in
        if B.logs_to_stdout then (
          fprintf ppf {|@[<7>printf(@[<h>"%s%%d: # %s\n", log_id@]);@]@ |}
            !Utils.captured_log_prefix debug_line;
          fprintf ppf
            {|@[<7>printf(@[<h>"%s%%d: %s[%%u]{=%%g} = %%g = %s\n",@]@ log_id,@ %a,@ %s[%a],@ new_set_v%a);@]@ |}
            !Utils.captured_log_prefix ident v_code pp_array_offset offset ident pp_array_offset
            offset pp_args v_idcs)
        else (
          fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"# %s\n"@]);@]@ |} debug_line;
          fprintf ppf
            {|@[<7>fprintf(log_file,@ @[<h>"%s[%%u]{=%%g} = %%g = %s\n",@]@ %a,@ %s[%a],@ new_set_v%a);@]@ |}
            ident v_code pp_array_offset offset ident pp_array_offset offset pp_args v_idcs);
        if not B.logs_to_stdout then fprintf ppf "fflush(log_file);@ ";
        fprintf ppf "@[<2>%s[@,%a] =@ new_set_v;@]@;<1 -2>}@]@ " ident pp_array_offset
          (idcs, dims))
      else
        fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " ident pp_array_offset (idcs, dims) loop_f llv;
      for _ = 1 to num_closing_braces do
        fprintf ppf "@]@ }@,"
      done
  | Comment message ->
      (* With logging enabled the comment becomes a log line, otherwise a C comment. *)
      if Utils.debug_log_from_routines () then
        if B.logs_to_stdout then
          fprintf ppf {|printf(@[<h>"%s%%d: COMMENT: %s\n",@] log_id);@ |}
            !Utils.captured_log_prefix
            (String.substr_replace_all ~pattern:"%" ~with_:"%%" message)
        else
          fprintf ppf {|fprintf(log_file,@ @[<h>"COMMENT: %s\n"@]);@ |}
            (String.substr_replace_all ~pattern:"%" ~with_:"%%" message)
      else fprintf ppf "/* %s */@ " message
  | Staged_compilation callback -> callback ()
  | Set_local (Low_level.{ scope_id; tn = { prec; _ } }, value) ->
      (* Assignment to the local variable of scope [scope_id], named [v] followed by the id. *)
      let num_closing_braces = pp_top_locals ppf value in
      fprintf ppf "@[<2>v%d =@ %a;@]" scope_id (pp_float @@ Lazy.force prec) value;
      for _ = 1 to num_closing_braces do
        fprintf ppf "@]@ }@,"
      done
(* [pp_top_locals ppf vcomp] opens one block per local scope reachable in [vcomp] without
   crossing another local scope: it prints the opening brace, the declaration of the scope
   variable initialized to 0, and the scope body. It returns how many blocks were opened; the
   caller prints the matching closing braces. Only the operand selected by [Arg1] or [Arg2]
   is visited.
   NOTE(review): both recursive calls under [+] print as a side effect, and OCaml does not
   specify the evaluation order of the operands, so the relative order of two sibling scopes
   depends on the compiler. *)
and pp_top_locals ppf (vcomp : Low_level.float_t) : int =
  match vcomp with
  | Get_local _ | Get_global _ | Get _ | Constant _ | Embed_index _ -> 0
  | Unop (_, arg) -> pp_top_locals ppf arg
  | Binop (Arg1, kept, _) -> pp_top_locals ppf kept
  | Binop (Arg2, _, kept) -> pp_top_locals ppf kept
  | Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2
  | Local_scope { id = { scope_id; tn = { prec; _ } }; body; orig_indices = _ } ->
      fprintf ppf "@[<2>{@ %s v%d = 0;@ " (B.typ_of_prec (Lazy.force prec)) scope_id;
      pp_ll ppf body;
      pp_print_space ppf ();
      1
(* [pp_float prec ppf value] prints the expression computing [value] at precision [prec].
   Operands stored at another precision are wrapped in the texts given by
   [B.convert_precision]. Every array that is read is recorded in [visited].
   Raises [Failure] for global accesses other than an indexed merge buffer. *)
and pp_float (prec : Ops.prec) ppf value =
  let recur = pp_float prec in
  (* Texts to print before and after a value of precision [from]. *)
  let cast ~from = B.convert_precision ~from ~to_:prec in
  match value with
  | Local_scope { id; _ } ->
      (* The scope itself is printed by [pp_top_locals]; here only its variable is read. *)
      recur ppf @@ Get_local id
  | Get_local id ->
      let before, after = cast ~from:(Lazy.force id.tn.prec) in
      fprintf ppf "%sv%d%s" before id.scope_id after
  | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
      (* The merge buffer is indexed with the dimensions of its source node. *)
      let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
      let before, after = cast ~from:(Lazy.force tn.prec) in
      let offset = (idcs, Lazy.force tn.dims) in
      fprintf ppf "@[<2>%smerge_buffer[%a@;<0 -2>]%s@]" before pp_array_offset offset after
  | Get_global _ -> failwith "C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
  | Get (tn, idcs) ->
      Hash_set.add visited tn;
      let ident = get_ident tn in
      let before, after = cast ~from:(Lazy.force tn.prec) in
      let offset = (idcs, Lazy.force tn.dims) in
      fprintf ppf "@[<2>%s%s[%a@;<0 -2>]%s@]" before ident pp_array_offset offset after
  | Constant c ->
      let before, after = cast ~from:Ops.double in
      (* A negative constant without a conversion prefix is parenthesized. *)
      let parenthesize = String.is_empty before && Float.(c < 0.0) in
      let opening = if parenthesize then "(" else before in
      let closing = if parenthesize then ")" ^ after else after in
      fprintf ppf "%s%.16g%s" opening c closing
  | Embed_index idx ->
      let before, after = cast ~from:Ops.double in
      fprintf ppf "%s%a%s" before pp_index_axis idx after
  | Binop (Arg1, kept, _) -> recur ppf kept
  | Binop (Arg2, _, kept) -> recur ppf kept
  | Binop (op, lhs, rhs) ->
      let opening, middle, closing = B.binop_syntax prec op in
      fprintf ppf "@[<1>%s%a%s@ %a@]%s" opening recur lhs middle recur rhs closing
  | Unop (op, arg) ->
      let opening, closing = B.unop_syntax prec op in
      fprintf ppf "@[<1>%s%a@]%s" opening recur arg closing
(* [debug_float prec value] describes [value] for the debug log lines printed by the [Set]
   case of [pp_ll]. It returns a pair:
   - a fragment of a C format string showing the expression, with a [%u] conversion for every
     array offset and a [%g] conversion for every value that is read;
   - the arguments for these conversions, in order: [`Accessor] carries the indices and
     dimensions of an offset, [`Value] the text of the expression that is read.
   Raises [Failure] for global accesses other than an indexed merge buffer. *)
and debug_float (prec : Ops.prec) (value : Low_level.float_t) : string * 'a list =
  let loop = debug_float prec in
  match value with
  | Local_scope { id; _ } ->
      (* Only the variable of the scope is shown, not its body. *)
      loop @@ Get_local id
  | Get_local id ->
      let prefix, postfix = B.convert_precision ~from:(Lazy.force id.tn.prec) ~to_:prec in
      let v = String.concat [ prefix; "v"; Int.to_string id.scope_id; postfix ] in
      (v ^ "{=%g}", [ `Value v ])
  | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
      let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
      let prefix, postfix = B.convert_precision ~from:(Lazy.force tn.prec) ~to_:prec in
      let dims = Lazy.force tn.dims in
      let v =
        sprintf "@[<2>%smerge_buffer[%s@;<0 -2>]%s@]" prefix
          (array_offset_to_string (idcs, dims))
          postfix
      in
      ( String.concat [ prefix; "merge_buffer[%u]"; postfix; "{=%g}" ],
        [ `Accessor (idcs, dims); `Value v ] )
  | Get_global _ ->
      (* Reports the module this code lives in, like the same case of [pp_float]. *)
      failwith "C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
  | Get (tn, idcs) ->
      let dims = Lazy.force tn.dims in
      let ident = get_ident tn in
      let prefix, postfix = B.convert_precision ~from:(Lazy.force tn.prec) ~to_:prec in
      let v =
        sprintf "@[<2>%s%s[%s@;<0 -2>]%s@]" prefix ident
          (array_offset_to_string (idcs, dims))
          postfix
      in
      ( String.concat [ prefix; ident; "[%u]"; postfix; "{=%g}" ],
        [ `Accessor (idcs, dims); `Value v ] )
  | Constant c ->
      (* Constants and indices are shown literally and need no arguments. *)
      let prefix, postfix = B.convert_precision ~from:Ops.double ~to_:prec in
      (prefix ^ Float.to_string c ^ postfix, [])
  | Embed_index (Fixed_idx i) -> (Int.to_string i, [])
  | Embed_index (Iterator s) -> (Indexing.symbol_ident s, [])
  | Binop (Arg1, v1, _v2) -> loop v1
  | Binop (Arg2, _v1, v2) -> loop v2
  | Binop (op, v1, v2) ->
      let prefix, infix, postfix = B.binop_syntax prec op in
      let v1, idcs1 = loop v1 in
      let v2, idcs2 = loop v2 in
      (String.concat [ prefix; v1; infix; " "; v2; postfix ], idcs1 @ idcs2)
  | Unop (op, v) ->
      let prefix, postfix = B.unop_syntax prec op in
      let v, idcs = loop v in
      (String.concat [ prefix; v; postfix ], idcs)
in
pp_ll ppf llc
(* [compile_proc ~name ppf idx_params ~is_global lowered] prints the definition of the function
   [name] implementing the lowered code, and returns its parameters in signature order, each
   as the declaration text paired with the source of its value. [is_global] is the set of nodes
   that already have a global definition (see [compile_globals]). As a side effect, a tag
   describing how each array is provided is appended to the [backend_info] of its node. *)
let%diagn_sexp compile_proc ~name ppf idx_params ~is_global
Low_level.{ traced_store; llc; merge_node } =
let open Stdlib.Format in
(* One pointer parameter per array that is in context and not defined globally. *)
let params : (string * param_source) list =
List.rev
@@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:node params ->
(* NOTE(review): the integer passed to [Tn.is_virtual_force] here and below is opaque in this
   file; presumably it identifies the call site -- confirm in [Tnode]. *)
let backend_info =
Sexp.Atom
(if B.is_in_context node then "From_context"
else if Hash_set.mem is_global tn then "Constant_from_host"
else if Tn.is_virtual_force tn 3331 then "Virtual"
else "Local_only")
in
(* The tag is appended only if it is not already present. *)
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
if B.is_in_context node && not (Hash_set.mem is_global tn) then
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
else params)
in
(* One [int] parameter per static index symbol. *)
let idx_params =
List.map idx_params ~f:(fun s ->
("int " ^ Indexing.symbol_ident s.Indexing.static_symbol, Static_idx s))
in
(* The logging parameter, present only when logging from routines is enabled. *)
let log_file =
if Utils.debug_log_from_routines () then
[
((if B.logs_to_stdout then "int log_id" else "const char* log_file_name"), Log_file_name);
]
else []
in
(* The merge buffer parameter, present only when the code has a merge node. *)
let merge_param =
Option.(
to_list
@@ map merge_node ~f:(fun tn ->
("const " ^ B.typ_of_prec (Lazy.force tn.prec) ^ " *merge_buffer", Merge_buffer)))
in
let params = log_file @ merge_param @ idx_params @ params in
(* Parameters are ordered by their declaration text; the printed signature and the returned
   list both use this order. *)
let params =
List.sort params ~compare:(fun (p1_name, _) (p2_name, _) -> compare_string p1_name p2_name)
in
fprintf ppf "@[<v 2>@[<hv 4>%s%svoid %s(@,@[<hov 0>%a@]@;<0 -4>)@] {@ " B.main_kernel_prefix
(if String.is_empty B.main_kernel_prefix then "" else " ")
name
(pp_print_list ~pp_sep:pp_comma pp_print_string)
@@ List.map ~f:fst params;
if not (String.is_empty B.kernel_prep_line) then fprintf ppf "%s@ " B.kernel_prep_line;
(* When logging to a file, the generated function opens it for writing on entry.
   NOTE(review): no matching [fclose] is printed by this function -- confirm whether the file
   is closed elsewhere. *)
if (not (List.is_empty log_file)) && not B.logs_to_stdout then
fprintf ppf {|FILE* log_file = fopen(log_file_name, "w");@ |};
if Utils.debug_log_from_routines () then (
fprintf ppf "/* Debug initial parameter state. */@ ";
(* Pointers are logged with [%p], static indices with [%d]; the logging parameter itself is
   not logged. *)
List.iter
~f:(function
| p_name, Merge_buffer ->
if B.logs_to_stdout then
fprintf ppf
{|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)merge_buffer);@]@ |}
!Utils.captured_log_prefix p_name
else
fprintf ppf
{|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)merge_buffer);@]@ |} p_name
| _, Log_file_name -> ()
| p_name, Param_ptr tn ->
if B.logs_to_stdout then
fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%p\n",@] log_id, (void*)%s);@]@ |}
!Utils.captured_log_prefix p_name
@@ get_ident tn
else
fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%p\n",@] (void*)%s);@]@ |} p_name
@@ get_ident tn
| p_name, Static_idx s ->
if B.logs_to_stdout then
fprintf ppf {|@[<7>printf(@[<h>"%s%%d: %s = %%d\n",@] log_id, %s);@]@ |}
!Utils.captured_log_prefix p_name
@@ Indexing.symbol_ident s.Indexing.static_symbol
else
fprintf ppf {|@[<7>fprintf(log_file,@ @[<h>"%s = %%d\n",@] %s);@]@ |} p_name
@@ Indexing.symbol_ident s.Indexing.static_symbol)
params);
fprintf ppf "/* Local declarations and initialization. */@ ";
(* Arrays that are neither virtual, in context, nor global are declared as local arrays of
   [Tn.num_elems] elements, with a zero initializer if flagged. The remaining non-virtual
   arrays flagged as zero initialized are cleared with a memset. *)
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
if not (Tn.is_virtual_force tn 333 || B.is_in_context node || Hash_set.mem is_global tn)
then
fprintf ppf "%s %s[%d]%s;@ "
(B.typ_of_prec @@ Lazy.force tn.prec)
(get_ident tn) (Tn.num_elems tn)
(if node.zero_initialized then " = {0}" else "")
else if (not (Tn.is_virtual_force tn 333)) && node.zero_initialized then pp_zero_out ppf tn);
fprintf ppf "@,/* Main logic. */@ ";
compile_main ~traced_store ppf llc;
fprintf ppf "@;<0 -2>}@]@.";
params
end
(* [check_merge_buffer ~merge_buffer ~code_node] verifies that the node currently held in the
   device merge buffer is the one the code expects. Code that expects no merge node always
   passes.
   @raise Utils.User_error when the code expects a node and the merge buffer is empty or holds
   a different node. *)
let check_merge_buffer ~merge_buffer ~code_node =
  let device_node = Option.map !merge_buffer ~f:snd in
  let describe node = Option.value_map node ~default:"none" ~f:(fun tn -> Tn.debug_name tn) in
  let consistent =
    match code_node with
    | None -> true
    | Some expected -> Option.exists device_node ~f:(fun actual -> Tn.equal actual expected)
  in
  if not consistent then
    raise
    @@ Utils.User_error
         (String.concat
            [
              "Merge buffer mismatch, on device: ";
              describe device_node;
              ", expected by code: ";
              describe code_node;
            ])