package nx

  1. Overview
  2. Docs

Source file parallel.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
(** A unit of work for a single domain. Both the workers and the calling
    domain run a task as [compute start_idx end_idx]. *)
type task = { start_idx : int; end_idx : int; compute : int -> int -> unit }
(** [WaitCompletion n] is performed to block until the pool's [completed]
    counter has reached at least [n]. It is handled by [run]. *)
type _ Effect.t += WaitCompletion : int -> unit Effect.t

(** Shared state of the worker pool. *)
type pool = {
  num_workers : int;
      (** Number of worker domains spawned by [setup_pool]; the calling domain
          acts as one additional participant (see [get_num_domains]). *)
  task_assignments : task option array;
      (** One slot per worker. [parallel_execute] fills the slots and each
          worker empties its own slot when it picks the task up; both happen
          while holding [mutex]. *)
  completed : int Atomic.t;
      (** Incremented by a worker after each task it has run; reset to [0] at
          the start of every [parallel_execute]. *)
  generation : int Atomic.t;
      (** Incremented once per [parallel_execute] call, before tasks are
          published. *)
  mutex : Mutex.t;  (** Held while reading or writing [task_assignments]. *)
  work_available : Condition.t;
      (** Broadcast by [parallel_execute] once all worker slots are filled. *)
}

(** Cache for the pool created on demand by [get_or_setup_pool]; [None] until
    the first call. *)
let current_pool = ref None

(** [setup_pool ()] builds a fresh pool and spawns its worker domains.

    The pool has [Domain.recommended_domain_count () - 1] workers. Each worker
    loops forever: it sleeps on [work_available] until its slot in
    [task_assignments] is filled (or the generation differs from the last one
    it observed), claims the task under the mutex, runs it outside the mutex
    and then increments [completed]. An exception raised by a task is printed
    to [stderr] and otherwise ignored; the task still counts as completed.

    The spawned domains are never joined.

    @return the newly created pool. *)
let setup_pool () =
  let num_workers = Domain.recommended_domain_count () - 1 in
  let pool =
    {
      num_workers;
      task_assignments = Array.make num_workers None;
      completed = Atomic.make 0;
      generation = Atomic.make 0;
      mutex = Mutex.create ();
      work_available = Condition.create ();
    }
  in
  (* Log a task failure without letting it kill the worker domain. *)
  let report_failure id exn =
    Printf.eprintf "Worker %d: Exception in task: %s\n" id
      (Printexc.to_string exn);
    flush stderr
  in
  let worker id =
    (* Starts at -1 so the very first pass never blocks: it only records the
       current generation and loops back. *)
    let seen_gen = ref (-1) in
    while true do
      Mutex.lock pool.mutex;
      let gen_at_entry = Atomic.get pool.generation in
      while
        Option.is_none pool.task_assignments.(id) && !seen_gen = gen_at_entry
      do
        Condition.wait pool.work_available pool.mutex
      done;
      (* Still holding the mutex: remember the generation and take whatever is
         in our slot. Clearing an already-empty slot is a no-op. *)
      seen_gen := Atomic.get pool.generation;
      let claimed = pool.task_assignments.(id) in
      pool.task_assignments.(id) <- None;
      Mutex.unlock pool.mutex;
      match claimed with
      | None ->
          (* New generation without a task for us: just loop back. *)
          ()
      | Some task ->
          (try task.compute task.start_idx task.end_idx
           with exn -> report_failure id exn);
          Atomic.incr pool.completed
    done
  in
  for id = 0 to num_workers - 1 do
    ignore (Domain.spawn (fun () -> worker id))
  done;
  pool

(** [get_or_setup_pool ()] returns the pool cached in [current_pool], creating
    it with [setup_pool] and caching it on the first call. *)
let get_or_setup_pool () =
  let create_and_cache () =
    let fresh = setup_pool () in
    current_pool := Some fresh;
    fresh
  in
  match !current_pool with
  | None -> create_and_cache ()
  | Some cached -> cached

(** [get_num_domains pool] is the number of domains taking part in a parallel
    run: the pool's workers plus the calling domain. *)
let get_num_domains { num_workers; _ } = succ num_workers

(** [run pool f] evaluates [f ()] under a handler for {!WaitCompletion}.

    When [f] performs [WaitCompletion target], the handler spins (calling
    [Domain.cpu_relax] between polls) until [pool.completed] is at least
    [target], then resumes [f]. All other effects are left unhandled. *)
let run pool f =
  let effc :
      type a.
      a Effect.t -> ((a, unit) Effect.Deep.continuation -> unit) option =
    function
    | WaitCompletion target ->
        Some
          (fun (k : (a, unit) Effect.Deep.continuation) ->
            (* Busy-wait: workers bump [completed] as they finish. *)
            while Atomic.get pool.completed < target do
              Domain.cpu_relax ()
            done;
            Effect.Deep.continue k ())
    | _ -> None
  in
  Effect.Deep.try_with f () { Effect.Deep.effc }

(** [parallel_execute pool tasks] runs [tasks] across the pool and returns once
    every worker has finished.

    [tasks.(i)] for [i < pool.num_workers] is handed to worker [i]; the last
    task, [tasks.(pool.num_workers)], is run on the calling domain.

    If the calling domain's task raises, the exception is re-raised (with its
    original backtrace) only after all workers have finished. Returning early
    would leave workers running tasks the caller believes are done, and their
    late increments of [completed] would corrupt the next call's wait.

    @raise Invalid_argument
      if [Array.length tasks] differs from [get_num_domains pool]. *)
let parallel_execute pool tasks =
  run pool (fun () ->
      let num_tasks = Array.length tasks in
      if num_tasks <> get_num_domains pool then
        invalid_arg
          "parallel_execute: number of tasks must equal num_workers + 1";
      Atomic.set pool.completed 0;
      (* Publish one task per worker under the mutex, then wake them all. *)
      Mutex.lock pool.mutex;
      Atomic.incr pool.generation;
      for i = 0 to pool.num_workers - 1 do
        pool.task_assignments.(i) <- Some tasks.(i)
      done;
      Condition.broadcast pool.work_available;
      Mutex.unlock pool.mutex;
      (* The calling domain takes the last chunk. Capture any failure instead
         of letting it escape, so the wait below always happens. *)
      let main_task = tasks.(pool.num_workers) in
      let main_failure =
        match main_task.compute main_task.start_idx main_task.end_idx with
        | () -> None
        | exception exn -> Some (exn, Printexc.get_raw_backtrace ())
      in
      Effect.perform (WaitCompletion pool.num_workers);
      match main_failure with
      | None -> ()
      | Some (exn, backtrace) -> Printexc.raise_with_backtrace exn backtrace)

(** [parallel_for pool start end_ compute_chunk] splits the index range
    [start .. end_] ([end_] inclusive: the iteration count is
    [end_ - start + 1]) into one contiguous chunk per domain and runs
    [compute_chunk lo hi] for each chunk, where [hi] is [lo] plus the chunk
    length (one past the chunk's last index).

    Chunk lengths differ by at most one; the first
    [count mod domains] chunks get the extra iteration. When there are fewer
    iterations than domains, the trailing chunks have length [0] and
    [compute_chunk] is still called for them with [lo = hi].

    An empty or inverted range does nothing; a single iteration is run
    directly on the calling domain without involving the pool. *)
let parallel_for pool start end_ compute_chunk =
  let span = end_ - start + 1 in
  if span = 1 then compute_chunk start (start + 1)
  else if span > 1 then (
    let domains = get_num_domains pool in
    let base = span / domains and extra = span mod domains in
    let make_task d =
      (* Chunks before [d] contributed [d * base] iterations plus one extra
         for each of the first [min d extra] chunks. *)
      let lo = start + (d * base) + min d extra in
      let width = if d < extra then base + 1 else base in
      { start_idx = lo; end_idx = lo + width; compute = compute_chunk }
    in
    parallel_execute pool (Array.init domains make_task))

(** [parallel_for_reduce pool start end_ body reduce init] computes a reduction
    over the index range [start .. end_] ([end_] inclusive: the iteration
    count is [end_ - start + 1]).

    The range is split into one contiguous chunk per domain, exactly as in
    [parallel_for]. Each domain evaluates [body lo hi] on its chunk, where [hi]
    is [lo] plus the chunk length, and stores the partial result. Once all
    domains have finished, the partial results are combined left to right, in
    chunk order, as [reduce acc partial] starting from [init].

    When there are fewer iterations than domains, [body] is still called on
    the zero-length trailing chunks (with [lo = hi]).

    An empty or inverted range ([end_ < start]) returns [init] without calling
    [body] or touching the pool, matching the guard in [parallel_for]; without
    it a negative iteration count yields negative chunk lengths and [body]
    would be handed inverted ranges lying outside [start .. end_]. *)
let parallel_for_reduce pool start end_ body reduce init =
  let total_iterations = end_ - start + 1 in
  if total_iterations <= 0 then init
  else
    let total_domains = get_num_domains pool in
    let results = Array.make total_domains init in
    let chunk_size = total_iterations / total_domains in
    let remainder = total_iterations mod total_domains in
    let tasks =
      Array.init total_domains (fun d ->
          let start_idx = start + (d * chunk_size) + min d remainder in
          let len = chunk_size + if d < remainder then 1 else 0 in
          let end_idx = start_idx + len in
          (* Each domain writes only its own slot, so no locking is needed;
             the slots are read only after [parallel_execute] has returned. *)
          let compute lo hi = results.(d) <- body lo hi in
          { start_idx; end_idx; compute })
    in
    parallel_execute pool tasks;
    Array.fold_left reduce init results
OCaml

Innovation. Community. Security.