package caisar

Source file nnet.ml

(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2024                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base
module Format = Stdlib.Format
module Sys = Stdlib.Sys
module Filename = Stdlib.Filename
module Fun = Stdlib.Fun

(* A parsed NNet model: header metadata (number of layers, inputs, outputs,
   maximum and per-layer sizes), optional input normalization data (minimum,
   maximum, mean and range values), the raw weight and bias rows as read from
   the file, and the corresponding NIR graph. *)
type t = {
  n_layers : int;
  n_inputs : int;
  n_outputs : int;
  max_layer_size : int;
  layer_sizes : int list;
  min_input_values : float list option;
  max_input_values : float list option;
  mean_values : (float list * float) option;
  range_values : (float list * float) option;
  weights_biases : float list list;
  nir : Nir.Ngraph.t;
}

let to_nir weights_biases n_inputs layer_sizes =
  let open Nir in
  let create_input_node in_shape = Node.create (Input { shape = in_shape }) in
  (* [weights_biases] is a list of lists describing the weights and biases of
   * the network: each inner list is either one row of a layer's weight matrix
   * or a single-element bias entry. *)
  let aggregated_wb (weights_biases : float list list) layer_sizes =
    let rec slice l ~start ~stop =
      (* Inclusive bounds: a single-element slice (start = stop) must be
         accepted, e.g. for a layer of size one. *)
      assert (stop >= start);
      assert (stop < List.length weights_biases);
      match l with
      | [] -> failwith "Cannot take a slice from an empty list"
      | h :: t ->
        let tail =
          if stop = 0 then [] else slice t ~start:(start - 1) ~stop:(stop - 1)
        in
        if start > 0 then tail else h :: tail
    in
    let rec aggregs full_wb l_sizes acc acc_idx prev_size =
      match l_sizes with
      | [] -> acc
      | x :: y ->
        (* TODO: a much more efficient approach would be to consume full_wb
         * incrementally instead of slicing the full list every time. *)
        let w_idx = acc_idx + x in
        let b_idx = w_idx + x in
        let w = List.concat @@ slice full_wb ~start:acc_idx ~stop:(w_idx - 1)
        and b = List.concat @@ slice full_wb ~start:w_idx ~stop:(b_idx - 1)
        and sh = Shape.of_array [| prev_size; x |] in
        aggregs full_wb y ((w, b, sh) :: acc) b_idx x
    in
    (* The first element of layer_sizes is the input size, so skip it. *)
    aggregs weights_biases (List.drop layer_sizes 1) [] 0
      (List.nth_exn layer_sizes 0)
  in
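  (* For instance (illustrative values), with layer_sizes = [2; 3; 1], the
     file provides 3 weight rows and 3 bias rows for the hidden layer, then 1
     weight row and 1 bias row for the output layer; [aggregated_wb] turns
     these 8 rows into [(w2, b2, [|3; 1|]); (w1, b1, [|2; 3|])], one triple
     per layer, last layer first. *)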

  let rec traverse_wb wb acc =
    (* Recursively traverse weights and biases: build the necessary nodes and
       return the last node of a simple neural network consisting of Matmul,
       Add and ReLu nodes. [wb] contains one triple per layer, from last
       layer to first: the flattened weight matrix, the biases, and the
       weight matrix shape. *)
    match wb with
    | [] -> create_input_node acc
    | (weights, biases, sh_w) :: rest ->
      (* The recursive call builds the subnetwork feeding the current node. *)
      let input_node = traverse_wb rest acc in
      let weights_tensor =
        Nir.Gentensor.of_float_array ~shape:sh_w (Array.of_list weights)
      in
      let weights_node =
        Node.create (Node.Constant { data = weights_tensor })
      in
      let matmul_node =
        Node.create (Node.Matmul { input1 = input_node; input2 = weights_node })
      in
      let biases_tensor = Nir.Gentensor.of_float_array (Array.of_list biases) in
      let biases_node = Node.create (Node.Constant { data = biases_tensor }) in
      let add_node =
        Node.create (Add { input1 = biases_node; input2 = matmul_node })
      in
      let relu_node = Node.create (Node.ReLu { input = add_node }) in
      relu_node
  in
  let in_sh = Shape.of_list [ n_inputs ] in
  Nir.Ngraph.create
    (traverse_wb (aggregated_wb weights_biases layer_sizes) in_sh)

(* NNet format handling. *)

let nnet_format_error s =
  Error (Format.sprintf "NNet format error: %s condition not satisfied." s)

(* Parse a single NNet format line: split the line according to the CSV
   format, and convert each field into a number by means of converter [f].
   Fields that [f] cannot convert are silently dropped. *)
let handle_nnet_line ~f in_channel =
  List.filter_map
    ~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
    (Csv.next in_channel)

(* Skip the header part, i.e. comments, of the NNet format. *)
let skip_nnet_header filename in_channel =
  let exception End_of_header in
  let pos_in = ref (Stdlib.pos_in in_channel) in
  try
    while true do
      let line = Stdlib.input_line in_channel in
      if not (Str.string_match (Str.regexp "//") line 0)
      then raise End_of_header
      else pos_in := Stdlib.pos_in in_channel
    done;
    assert false
  with
  | End_of_header ->
    (* At this point we have read one line past the header part: seek back. *)
    Stdlib.seek_in in_channel !pos_in;
    Ok ()
  | End_of_file ->
    Error (Format.sprintf "NNet model not found in file '%s'." filename)
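(* For instance (illustrative content), a NNet file may start with:
     // Neural network metadata
     // Produced by some tool
   [skip_nnet_header] consumes these comment lines and leaves the channel
   positioned at the first data line. *)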

(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
let handle_nnet_basic_info in_channel =
  match handle_nnet_line ~f:Int.of_string in_channel with
  | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
    Ok (n_layers, n_inputs, n_outputs, max_layer_size)
  | _ -> nnet_format_error "second"
  | exception End_of_file -> nnet_format_error "second"

(* Retrieve size of each layer, including inputs and outputs. *)
let handle_nnet_layer_sizes n_layers in_channel =
  try
    let layer_sizes = handle_nnet_line ~f:Int.of_string in_channel in
    if List.length layer_sizes = n_layers + 1
    then Ok layer_sizes
    else nnet_format_error "third"
  with End_of_file -> nnet_format_error "third"

(* Skip unused flag. *)
let handle_nnet_unused_flag in_channel =
  try
    let _ = Csv.next in_channel in
    Ok ()
  with End_of_file -> nnet_format_error "fourth"

(* Retrieve minimum values of inputs. *)
let handle_nnet_min_input_values n_inputs in_channel =
  try
    let min_input_values = handle_nnet_line ~f:Float.of_string in_channel in
    if List.length min_input_values = n_inputs
    then Ok min_input_values
    else nnet_format_error "fifth"
  with End_of_file -> nnet_format_error "fifth"

(* Retrieve maximum values of inputs. *)
let handle_nnet_max_input_values n_inputs in_channel =
  try
    let max_input_values = handle_nnet_line ~f:Float.of_string in_channel in
    if List.length max_input_values = n_inputs
    then Ok max_input_values
    else nnet_format_error "sixth"
  with End_of_file -> nnet_format_error "sixth"

(* Retrieve mean values of inputs and one value for all outputs. *)
let handle_nnet_mean_values n_inputs in_channel =
  try
    let mean_values = handle_nnet_line ~f:Float.of_string in_channel in
    if List.length mean_values = n_inputs + 1
    then
      let mean_input_values, mean_output_value =
        List.split_n mean_values n_inputs
      in
      Ok (mean_input_values, List.hd_exn mean_output_value)
    else nnet_format_error "seventh"
  with End_of_file -> nnet_format_error "seventh"

(* Retrieve range values of inputs and one value for all outputs. *)
let handle_nnet_range_values n_inputs in_channel =
  try
    let range_values = handle_nnet_line ~f:Float.of_string in_channel in
    if List.length range_values = n_inputs + 1
    then
      let range_input_values, range_output_value =
        List.split_n range_values n_inputs
      in
      Ok (range_input_values, List.hd_exn range_output_value)
    else nnet_format_error "eighth"
  with End_of_file -> nnet_format_error "eighth"
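(* For instance (illustrative values), with 2 inputs, the mean line
   "0.5,0.3,7.5" yields ([0.5; 0.3], 7.5): one mean per input plus a single
   mean shared by all outputs; range lines follow the same convention. *)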

(* Retrieve all layer weights and biases as appearing in the model. No special
   treatment is performed. *)
let handle_nnet_weights_and_biases in_channel =
  List.rev
    (Csv.fold_left ~init:[]
       ~f:(fun fll sl ->
         List.filter_map
           ~f:(fun s ->
             try Some (Float.of_string (String.strip s)) with _ -> None)
           sl
         :: fll)
       in_channel)
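(* The result is one float list per remaining CSV line, in file order: for
   each layer, first its weight matrix rows, then its bias rows. *)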

(* Retrieve [filename] NNet model metadata and weights according to the NNet
   format specification (see https://github.com/sisl/NNet for details). *)
let parse_in_channel ?(permissive = false) filename in_channel =
  let open Result in
  let ok_opt r =
    match r with
    | Ok x -> Ok (Some x)
    | Error _ as error -> if not permissive then error else Ok None
  in
  try
    skip_nnet_header filename in_channel >>= fun () ->
    let in_channel = Csv.of_channel in_channel in
    handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
    handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
    handle_nnet_unused_flag in_channel >>= fun () ->
    ok_opt (handle_nnet_min_input_values n_is in_channel)
    >>= fun min_input_values ->
    ok_opt (handle_nnet_max_input_values n_is in_channel)
    >>= fun max_input_values ->
    ok_opt (handle_nnet_mean_values n_is in_channel) >>= fun mean_values ->
    ok_opt (handle_nnet_range_values n_is in_channel) >>= fun range_values ->
    let weights_biases = handle_nnet_weights_and_biases in_channel in
    Csv.close_in in_channel;
    let nir = to_nir weights_biases n_is layer_sizes in
    Ok
      {
        n_layers = n_ls;
        n_inputs = n_is;
        n_outputs = n_os;
        max_layer_size = max_l_size;
        layer_sizes;
        min_input_values;
        max_input_values;
        mean_values;
        range_values;
        weights_biases;
        nir;
      }
  with
  | Csv.Failure (_nrecord, _nfield, msg) -> Error msg
  | Sys_error s -> Error s
  | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)

(* Convert a parsed model [t] into a NIR graph, pairing consecutive rows of
   [t.weights_biases] as weight and bias vectors. Note that this shadows the
   [to_nir] helper above, which works on per-layer aggregated triples. *)
let to_nir t =
  let open Nir in
  let create_input_node in_shape = Node.create (Input { shape = in_shape }) in
  let rec traverse_wb wb acc =
    (* Recursively traverse weights and biases: build the necessary nodes and
       return the last node of a simple neural network consisting of Matmul,
       Add and ReLu nodes. *)
    match wb with
    | [] -> create_input_node acc
    | [ _ ] -> failwith "Weights or biases missing."
    | weights :: biases :: rest ->
      (* The recursive call builds the subnetwork feeding the current node. *)
      let input_node = traverse_wb rest acc in
      let weights_tensor =
        Nir.Gentensor.of_float_array (Array.of_list weights)
      in
      let weights_node =
        Node.create (Node.Constant { data = weights_tensor })
      in
      let matmul_node =
        Node.create (Node.Matmul { input1 = input_node; input2 = weights_node })
      in
      let biases_tensor = Nir.Gentensor.of_float_array (Array.of_list biases) in
      let biases_node = Node.create (Node.Constant { data = biases_tensor }) in
      let add_node =
        Node.create (Add { input1 = matmul_node; input2 = biases_node })
      in
      let relu_node = Node.create (Node.ReLu { input = add_node }) in
      relu_node
  in
  let in_sh = Shape.of_list [ t.n_inputs ] in
  Nir.Ngraph.create (traverse_wb t.weights_biases in_sh)

let parse ?(permissive = false) filename =
  let in_channel = Stdlib.open_in filename in
  Fun.protect
    ~finally:(fun () -> Stdlib.close_in in_channel)
    (fun () -> parse_in_channel ~permissive filename in_channel)
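
(* Minimal usage sketch (illustrative; assumes client code and a model file
   named "model.nnet"):

     match Nnet.parse ~permissive:true "model.nnet" with
     | Ok t ->
       Format.printf "layers: %d, inputs: %d, outputs: %d@." t.n_layers
         t.n_inputs t.n_outputs
     | Error msg -> Format.eprintf "parse error: %s@." msg
*)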