package rune
Automatic differentiation and JIT compilation for OCaml
Install
Dune Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha0.tbz
sha256=a9a8a9787f8250337187bb7b21cb317c41bfd2ecf08bcfe0ab407c7b6660764d
sha512=fe13cf257c487e41efe2967be147d80fa94bac8996d3aab2b8fd16f0bbbd108c15e0e58c025ec9bf294d4a0d220ca2ba00c3b1b42fa2143f758c5f0ee4c15782
Source file scheduler.ml (library rune.jit)
(* scheduler.ml *)

open Ir

(* ───── kernel-spec record ───── *)

type kernel_spec_t = {
  name : string;
  nodes : any_node list;
  inputs : Var.t list; (* HL vars *)
  outputs : Var.t list; (* HL vars *)
  vars_metadata : (Var.t, var_metadata) Hashtbl.t;
}

(* ───── helpers on nodes ───── *)

let get_node_input_vars (Any_Node node) : Var.t list =
  match node with
  | Placeholder _ | Const_Scalar _ | Buffer _ | Vconst _ -> []
  | Binop { a_var; b_var; _ } -> [ a_var; b_var ]
  | Unary { in_var; _ }
  | Reduce_Axis { in_var; _ }
  | Expand { in_var; _ }
  | Reshape { in_var; _ }
  | Permute { in_var; _ }
  | Pad { in_var; _ }
  | Shrink { in_var; _ }
  | Flip { in_var; _ }
  | Cast { in_var; _ }
  | Bitcast { in_var; _ }
  | Contiguous { in_var; _ }
  | Copy { in_var; _ }
  | View { in_var; _ }
  | Valid { in_var; _ }
  | Detach { in_var; _ }
  | Contiguous_Backward { in_var; _ }
  | Fuse { in_var; _ } -> [ in_var ]
  | Ternary { a_var; b_var; c_var; _ } -> [ a_var; b_var; c_var ]
  | Cat { in_vars; _ } -> Array.to_list in_vars
  | Vectorize { in_vars; _ } -> Array.to_list in_vars
  | Contract { in_vars; _ } -> Array.to_list in_vars
  | Assign { target_var; updates; _ } ->
      target_var :: List.map (fun (src, _, _) -> src) (Array.to_list updates)
  | Threefry { ctr_var; key_var; _ } -> [ ctr_var; key_var ]
  | Gather { src_var; indices_var; _ } -> [ src_var; indices_var ]
  | Scatter { indices_var; updates_var; _ } -> [ indices_var; updates_var ]
  | Index { in_var; idx_var; valid_var; _ } ->
      in_var :: idx_var
      :: (match valid_var with None -> [] | Some v -> [ v ])
  | Gep { in_var; _ } -> [ in_var ]
  | Wmma { a_var; b_var; c_var; _ } -> [ a_var; b_var; c_var ]
  | Define_Var _ | Unique _ | Device _ -> []
  | Bind { sym_var; _ } -> [ sym_var ]
  | Buffer_View { buffer_var; _ } -> [ buffer_var ]
  | Multi { device_vars; _ } -> Array.to_list device_vars
  | Unroll { loop_var; _ } -> [ loop_var ]
  | Sink { deps; _ } -> Array.to_list deps
  | Kernel { input_vars; _ } -> Array.to_list input_vars
  | Custom { in_vars; _ } -> Array.to_list in_vars
  | Noop { in_var; _ } -> (
      match in_var with None -> [] | Some v -> [ v ])

let get_node_output_var (Any_Node node) : Var.t =
  match node with
  | Placeholder { out_var; _ }
  | Const_Scalar { out_var; _ }
  | Vconst { out_var; _ }
  | Buffer { out_var; _ }
  | Buffer_View { out_var; _ }
  | Binop { out_var; _ }
  | Unary { out_var; _ }
  | Ternary { out_var; _ }
  | Reduce_Axis { out_var; _ }
  | Expand { out_var; _ }
  | Reshape { out_var; _ }
  | Permute { out_var; _ }
  | Pad { out_var; _ }
  | Shrink { out_var; _ }
  | Flip { out_var; _ }
  | Cat { out_var; _ }
  | Cast { out_var; _ }
  | Bitcast { out_var; _ }
  | Contiguous { out_var; _ }
  | Copy { out_var; _ }
  | Assign { out_var; _ }
  | Threefry { out_var; _ }
  | Gather { out_var; _ }
  | Scatter { out_var; _ }
  | View { out_var; _ }
  | Valid { out_var; _ }
  | Index { out_var; _ }
  | Gep { out_var; _ }
  | Vectorize { out_var; _ }
  | Wmma { out_var; _ }
  | Define_Var { out_var; _ }
  | Bind { out_var; _ }
  | Detach { out_var; _ }
  | Contiguous_Backward { out_var; _ }
  | Multi { out_var; _ }
  | Fuse { out_var; _ }
  | Unroll { out_var; _ }
  | Contract { out_var; _ }
  | Kernel { out_var; _ }
  | Unique { out_var; _ }
  | Device { out_var; _ }
  | Custom { out_var; _ }
  | Noop { out_var; _ } -> out_var
  | Sink _ -> Var.fresh () (* Sink has no output var, create dummy *)

let is_boundary_node (Any_Node node) =
  match node with
  | Reduce_Axis _ | Buffer _ | Cat _ | Scatter _ | Assign _ | Wmma _ | Multi _
  | Kernel _ | Sink _ ->
      true
  | _ -> false

let is_fusible_elementwise (Any_Node node) =
  match node with
  | Binop _ | Unary _ | Ternary _ | Const_Scalar _ | Vconst _ | Expand _
  | Reshape _ | Permute _ | Placeholder _ | Pad _ | Shrink _ | Flip _ | Cast _
  | Bitcast _ | Contiguous _ | Copy _ | View _ | Valid _ | Detach _
  | Contiguous_Backward _ | Fuse _ | Noop _ ->
      true
  | _ -> false

(* ───── main scheduling pass ───── *)

let schedule (graph : Ir.graph_t) : kernel_spec_t list =
  let scheduled = ref [] in
  let current = ref [] in
  let kidx = ref 0 in
  (* Map var -> list of nodes that read it (for output detection) *)
  let var_consumers : (Var.t, any_node list) Hashtbl.t =
    Hashtbl.create (List.length graph.nodes)
  in
  List.iter
    (fun consumer ->
      List.iter
        (fun v ->
          let lst =
            Option.value ~default:[] (Hashtbl.find_opt var_consumers v)
          in
          Hashtbl.replace var_consumers v (consumer :: lst))
        (get_node_input_vars consumer))
    graph.nodes;
  let flush_current () =
    if !current <> [] then (
      let nodes = List.rev !current in
      let produced =
        List.filter
          (function
            | Ir.Any_Node (Placeholder _) | Ir.Any_Node (Sink _) -> false
            | _ -> true)
          nodes
        |> List.map get_node_output_var |> Var.Set.of_list
      in
      let node_inputs =
        List.concat_map get_node_input_vars nodes |> Var.Set.of_list
      in
      let inputs = Var.Set.diff node_inputs produced in
      let placeholder_inputs =
        List.filter_map
          (function
            | Ir.Any_Node (Placeholder { out_var; _ }) -> Some out_var
            | _ -> None)
          nodes
        |> Var.Set.of_list
      in
      let inputs = Var.Set.union inputs placeholder_inputs in
      (* vars metadata used inside kernel *)
      let vars_md = Hashtbl.create 16 in
      List.iter
        (fun (Any_Node n) ->
          Var.Set.iter
            (fun v ->
              match Hashtbl.find_opt graph.vars_metadata v with
              | Some m -> Hashtbl.replace vars_md v m
              | None -> ())
            (Var.Set.of_list
               (get_node_output_var (Any_Node n)
               :: get_node_input_vars (Any_Node n))))
        nodes;
      (* outputs: graph outputs OR consumed outside kernel *)
      let outputs =
        Var.Set.filter
          (fun v ->
            List.mem v graph.output_vars
            ||
            match Hashtbl.find_opt var_consumers v with
            | None -> false
            | Some readers ->
                List.exists (fun n -> not (List.memq n nodes)) readers)
          produced
        |> Var.Set.elements
      in
      scheduled :=
        {
          name = Printf.sprintf "kernel_%d" !kidx;
          nodes;
          inputs = Var.Set.elements inputs;
          outputs;
          vars_metadata = vars_md;
        }
        :: !scheduled;
      incr kidx;
      current := [])
  in
  List.iter
    (fun node ->
      let can_fuse =
        match !current with
        | [] -> true
        | _ ->
            (not (is_boundary_node node))
            && is_fusible_elementwise node
            && (* check all inputs are either kernel inputs or already
                  produced *)
            let produced =
              List.map get_node_output_var !current |> Var.Set.of_list
            in
            List.for_all
              (fun v -> List.mem v graph.input_vars || Var.Set.mem v produced)
              (get_node_input_vars node)
      in
      if can_fuse then current := node :: !current
      else (
        flush_current ();
        current := [ node ]))
    graph.nodes;
  flush_current ();
  List.rev !scheduled
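The grouping strategy of schedule is easiest to see on a toy example. The sketch below is not part of rune.jit: the node type, is_boundary, and partition are hypothetical stand-ins for Ir.any_node, is_boundary_node, and the traversal inside schedule, and it omits the input-availability check and input/output bookkeeping. It only illustrates the greedy "accumulate fusible nodes, flush at each boundary" partitioning that the real pass performs over graph.nodes.

(* Hypothetical, simplified model of the greedy fusion pass above.
   Not the rune.jit API: [node], [is_boundary], and [partition] exist only
   for illustration. *)
type node = Elementwise of string | Reduce of string

let is_boundary = function Reduce _ -> true | Elementwise _ -> false

(* Partition a node list into kernels: fusible nodes accumulate in [current];
   a boundary node flushes the accumulator and seeds a new kernel, mirroring
   [flush_current] / [can_fuse] in scheduler.ml. *)
let partition (nodes : node list) : node list list =
  let flush current acc =
    if current = [] then acc else List.rev current :: acc
  in
  let rec go nodes current acc =
    match nodes with
    | [] -> List.rev (flush current acc)
    | n :: rest ->
        if current <> [] && is_boundary n then go rest [ n ] (flush current acc)
        else go rest (n :: current) acc
  in
  go nodes [] []

let () =
  let kernels =
    partition
      [ Elementwise "add"; Elementwise "mul"; Reduce "sum"; Elementwise "neg" ]
  in
  (* Expected grouping: [add; mul] and [sum; neg] *)
  List.iteri
    (fun i k -> Printf.printf "kernel_%d: %d node(s)\n" i (List.length k))
    kernels

The real pass additionally requires that every input of a candidate node is either a graph input or a variable already produced inside the current kernel, so fusion never reaches across kernel boundaries; variables produced in one kernel and read by another (or listed in graph.output_vars) become that kernel's outputs.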