package caisar

  1. Overview
  2. Docs

Source file ovo.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2025                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base
open Result
module Format = Stdlib.Format
module Seq = Stdlib.Seq

(* Links:

   - https://datascience.stackexchange.com/questions/18374/predicting-probability-from-scikit-learn-svc-decision-function-with-decision-fun
     explains how the ovo (one-versus-one) procedure works.

   - https://scikit-learn.org/stable/modules/svm.html#multi-class-classification
     is a description of the SVMs.

   - https://github.com/abstract-machine-learning/saver#classifier-format is a
     description of the input format (notice that there are broken links). *)

(* BASIC PARSING TOOLS *)
(* State of the tokenizer: a CSV reader (opened with a space separator by
   [create_parser]) together with the tokens of the current record that have
   not been consumed yet. *)
type parser = {
  input : Csv.in_channel; (* the underlying CSV reader *)
  mutable tokens : string list;
    (* the tokens of the current record that remain to be consumed; refilled
       from [input] by [peek_token] when empty *)
}

(* [ovo_format_error s] is the [Error] carrying the standard OVO format
   diagnostic built around the description [s]. *)
let ovo_format_error s =
  Error ("OVO format error: " ^ s ^ " condition not satisfied.")

(* [create_parser inc] is a parser reading space-separated tokens from the
   channel [inc]; no token is read until the first request. *)
let create_parser inc =
  let input = Csv.of_channel ~separator:' ' inc in
  { input; tokens = [] }

let rec peek_token p =
  (* Returns the next non-empty token without consuming it. Empty tokens are
     dropped on the way, and a new record is fetched from the CSV reader when
     the current one is exhausted; the exception raised by [Csv.next] when no
     record is left is not caught here. *)
  match p.tokens with
  | [] ->
    (* current record exhausted: load the next one *)
    p.tokens <- Csv.next p.input;
    peek_token p
  | "" :: rest ->
    (* empty strings appear when a line ends with trailing spaces *)
    p.tokens <- rest;
    peek_token p
  | tok :: _ -> tok

let read_token p =
  (* Returns and consumes the next non-empty token.

     [peek_token] raises [End_of_file] when the input is exhausted; that
     exception is turned into [Error "EOF"] here so that a truncated file is
     reported through the [Result] pipeline like every other parsing error
     (previously the [Error "EOF"] branch could never be reached and the
     exception escaped). *)
  match peek_token p with
  | exception End_of_file -> Error "EOF"
  | _ -> (
    (* [peek_token] returned, hence the head of [p.tokens] is that token *)
    match p.tokens with
    | hd :: tl ->
      p.tokens <- tl;
      Ok hd
    | [] -> Error "EOF")

let read_int ?(msg = "") p =
  (* Consumes the next token and converts it to an int. A token that is not an
     int yields an OVO format error mentioning the token and the context
     [msg]. *)
  let int_of_token tok =
    match Int.of_string tok with
    | n -> Ok n
    | exception Failure _ ->
      ovo_format_error (Format.sprintf "(%s) not an int (%s)" tok msg)
  in
  read_token p >>= int_of_token

let read_float ?(msg = "") p =
  (* Consumes the next token and converts it to a float. A token that is not a
     float yields an OVO format error mentioning the token and the context
     [msg]. *)
  let float_of_token tok =
    match Float.of_string tok with
    | x -> Ok x
    | exception Failure _ ->
      ovo_format_error (Format.sprintf "(%s) not a float (%s)" tok msg)
  in
  read_token p >>= float_of_token

let read_keyword p k =
  (* Consumes the next token and checks that it is exactly the keyword [k];
     any other token is reported as an OVO format error. *)
  let check tok =
    match String.equal tok k with
    | true -> Ok ()
    | false ->
      ovo_format_error (Format.sprintf "expected keyword (%s) was (%s)" k tok)
  in
  read_token p >>= check

let read_float_array parser msg size =
  (* Consumes [size] floats and returns them as an array, in reading order.
     [msg] is the context reported if a token is not a float. Stops at the
     first error. *)
  let floats = Array.create ~len:size 0.0 in
  let rec fill_from idx =
    match idx = size with
    | true -> Ok floats
    | false ->
      read_float ~msg parser >>= fun value ->
      floats.(idx) <- value;
      fill_from (idx + 1)
  in
  fill_from 0

let read_2_dim_float parser msg size1 param_size2 =
  (* Consumes and returns a 2D array of floats made of [size1] rows. The
     length of the row of index [nb] (with [0 <= nb < size1]) is
     [param_size2 nb]. [msg] is the context reported if a token is not a
     float. Stops at the first error. *)
  let rec fill_mat nb mat =
    if nb = size1
    then Ok mat
    else
      (* the length of a row depends on the index of the row, not on the
         number of rows *)
      read_float_array parser msg (param_size2 nb) >>= fun line ->
      mat.(nb) <- line;
      fill_mat (nb + 1) mat
  in
  fill_mat 0 (Array.create ~len:size1 [||])

let read_3_dim_float parser msg size1 param_size2 param_size3 =
  (* Consumes and returns a 3D array of floats made of [size1] matrices. The
     matrix of index [idx] is read by [read_2_dim_float] with
     [param_size2 idx] rows and [param_size3] as row-length function. Stops at
     the first error. *)
  let tensor = Array.create ~len:size1 [| [||] |] in
  let rec fill_from idx =
    match idx = size1 with
    | true -> Ok tensor
    | false ->
      read_2_dim_float parser msg (param_size2 idx) param_size3
      >>= fun slice ->
      tensor.(idx) <- slice;
      fill_from (idx + 1)
  in
  fill_from 0

let check_eof parser =
  (* Succeeds only when no token is left: the absence of token is signalled by
     [peek_token] raising [End_of_file]. *)
  match peek_token parser with
  | exception End_of_file -> Ok ()
  | _ -> ovo_format_error "File not finished"

(* OVO DATA STRUCTURE *)
(* A support vector, stored as the array of its float parameters
   ([parse_support_vectors] reads [nb_ins] of them per vector). *)
type sv = float array

(* Description of one class, as read by [parse_classes_description]. *)
type class_descriptor = {
  name : string; (* name of the class *)
  nb_svs : int; (* number of support vectors associated with the class *)
}

(* The kernel function of the SVM. [build_kernel] turns [Linear] into the
   plain product of the input with the support vectors, and
   [Poly { gamma; degree; coef }] into
   [(gamma * product + coef) ^ degree]. When the file gives no gamma,
   [parse_kernel_type] uses [1 / nb_ins]. *)
type kernel_type =
  | Linear
  | Poly of {
      gamma : float; (* factor applied to the product *)
      degree : float; (* exponent *)
      coef : float; (* constant added before exponentiation *)
    }
(* | Rbf of float --- not implemented yet. *)

(* An SVM classifier using the one-versus-one (OVO) scheme, as built by
   [parse]. *)
type t = {
  nb_ins : int; (* the number of inputs *)
  nb_classes : int; (* the number of classes *)
  name_and_nb_sv_of_class : class_descriptor array;
    (* a description of each class, indexed by class *)
  start_of_sv_of_class : int array;
    (* the index of the first support vector associated with a class. When
       considered globally, the support vectors associated with class [cl] are
       indexed from [start_of_sv_of_class.(cl)] (included) to
       [start_of_sv_of_class.(cl) + nb_svs_of_cl] (excluded), where
       [nb_svs_of_cl] is the number of SVs saved in the description of
       [cl]. *)
  dual_coefs : float array array;
    (* the dual coefs. The dual coefficient of the [i]th SV of [cl] associated
       with [cl'] is [dual_coefs.(cl').(global i)] if [cl' < cl] and
       [dual_coefs.(cl'-1).(global i)] otherwise. The reason for this is that
       [cl] can never equal [cl'], so the value of [dual_coefs.(cl).(global
       i)] would be irrelevant; as a consequence, the values when [cl' > cl]
       are moved to the cells [cl'-1], hence the condition above. Cf. the
       implementation of [dual_coef] in [to_nn]. *)
  support_vectors : sv array array;
    (* the support vectors associated with each class:
       [support_vectors.(cl)] are the SVs of class [cl] *)
  intercept : float array;
    (* the intercept associated with each pair of classes. The intercept of
       [(cl,cl')] is [intercept.(pair_index cl cl')] *)
  k : kernel_type; (* the type of kernel for this SVM. *)
}

let pair_index ovo cl cl' =
  (* [pair_index ovo j l] is the rank of the pair of classes [(j,l)] among all
     pairs [0 <= j < l < c], enumerated lexicographically (first on [j], then
     on [l]), where [c] is [ovo.nb_classes]. The domain of [j] and [l] is not
     checked.

     It is computed as

     [(c * (c-1) / 2) - ((c-j) * (c-j-1) / 2) + l - j - 1]

     i.e. the total number of pairs, minus the number of pairs whose first
     component is at least [j], plus the offset of [l] after [j].

     Consequently:

     - [pair_index(0,1) = 0]

     - [pair_index(0,2) = 1]

     - ...

     - [pair_index(0,c-1) = c-2]

     - [pair_index(1,2) = c-1]

     - [pair_index(1,3) = c]

     - ...

     - [pair_index(c-2,c-1) = c * (c-1) / 2 - 1]. *)
  let nb = ovo.nb_classes in
  let nb_pairs = nb * (nb - 1) / 2 in
  let rest = nb - cl in
  let nb_pairs_from_cl = rest * (rest - 1) / 2 in
  nb_pairs - nb_pairs_from_cl + cl' - cl - 1

(* [global ovo cl i] is the "global" index of the [i]th SV of class [cl], i.e.
   [i] shifted by the index of the first SV of [cl]. *)
let global ovo cl i =
  let first_sv_of_class = ovo.start_of_sv_of_class.(cl) in
  first_sv_of_class + i

(* PARSING METHODS PER SE *)

let parse_header parser =
  (* Reads the header, i.e. the keyword [ovo] followed by two ints, and
     returns [Ok (nb_ins, nb_classes)]. *)
  read_keyword parser "ovo" >>= fun () ->
  read_int ~msg:"nb_ins" parser >>= fun nb_ins ->
  read_int ~msg:"nb_classes" parser >>| fun nb_classes -> (nb_ins, nb_classes)

let parse_classes_description parser nb_classes =
  (* Reads "name_of_class number_of_sv" [nb_classes] times and returns the
     resulting array of descriptions. When there are exactly two classes, the
     two descriptions are exchanged. NOTE(review): the reason for this
     exchange is not visible in this file; confirm the intended convention. *)
  let descriptions =
    (* placeholders, overwritten below *)
    Array.init nb_classes ~f:(fun _ -> { name = ""; nb_svs = -1 })
  in
  let rec read_from idx =
    if idx = nb_classes
    then Ok ()
    else
      read_token parser >>= fun name ->
      read_int ~msg:"nb SVs of class" parser >>= fun nb_svs ->
      descriptions.(idx) <- { name; nb_svs };
      read_from (idx + 1)
  in
  read_from 0 >>= fun () ->
  if nb_classes = 2
  then (
    let first = descriptions.(0) in
    descriptions.(0) <- descriptions.(1);
    descriptions.(1) <- first);
  Ok descriptions

let parse_support_vectors parser nb_ins name_and_nb_sv_of_class =
  (* Parses the SVs, which are expected as a flat enumeration of floats:

     first_param_of_first_sv ... last_param_of_first_sv ...
     first_param_of_last_sv ... last_param_of_last_sv

     (line breaks are not significant). Every SV has [nb_ins] parameters. The
     SVs of the first class come first, then those of the second class, etc.;
     the number of SVs of each class is taken from
     [name_and_nb_sv_of_class]. The result is indexed by class, then by SV,
     then by parameter. *)
  let nb_classes = Array.length name_and_nb_sv_of_class in
  let nb_svs_of_class cl = name_and_nb_sv_of_class.(cl).nb_svs in
  let sv_length _ = nb_ins in
  read_3_dim_float parser "SV parameters" nb_classes nb_svs_of_class sv_length

let parse_kernel_type parser nb_ins =
  (* Parses the description of the kernel function, one of:

     - [linear]

     - [poly] (or [polynomial]) followed by [gamma g d c], giving gamma,
       degree and coef;

     - [poly] (or [polynomial]) followed by [d c], giving degree and coef, in
       which case gamma is [1 / nb_ins].

     Any other first token is an OVO format error. *)
  let msg = "poly kernel parameters" in
  read_token parser >>= function
  | "linear" -> Ok Linear
  | "poly" | "polynomial" -> (
    match peek_token parser with
    | "gamma" -> (
      (* drop the [gamma] keyword, then read the three parameters *)
      read_token parser >>= fun _ ->
      read_float_array parser msg 3 >>= function
      | [| gamma; degree; coef |] -> Ok (Poly { gamma; degree; coef })
      | _ -> assert false (* Should have 3 parameters *))
    | _ -> (
      read_float_array parser msg 2 >>= function
      | [| degree; coef |] ->
        Ok (Poly { gamma = 1.0 /. Float.of_int nb_ins; degree; coef })
      | _ -> assert false (* Should have 2 parameters *)))
  | _ -> ovo_format_error "kernel"

(* [nb_ins ovo] is the number of inputs of [ovo]. *)
let nb_ins ovo = ovo.nb_ins
(* [nb_classes ovo] is the number of classes of [ovo]. *)
let nb_classes ovo = ovo.nb_classes
(* [class_name ovo cl] is the name stored in the description of index [cl] of
   [ovo]; [cl] must be a valid index of the array of descriptions. *)
let class_name ovo cl = ovo.name_and_nb_sv_of_class.(cl).name

let nb_svs name_and_nb_sv_of_class =
  (* total number of support vectors, all described classes included *)
  Array.fold name_and_nb_sv_of_class ~init:0 ~f:(fun total cdesc ->
    total + cdesc.nb_svs)

let start_of_svs name_and_nb_sv_of_class =
  (* [start_of_svs name_and_nb_sv_of_class] is the array giving, for each
     class [cl], the index of its first SV when the SVs of all classes are
     numbered consecutively, i.e. the sum of the numbers of SVs of the classes
     that precede [cl] in [name_and_nb_sv_of_class]. *)
  let starts = Array.create ~len:(Array.length name_and_nb_sv_of_class) 0 in
  let offset = ref 0 in
  Array.iteri name_and_nb_sv_of_class ~f:(fun cl desc ->
    starts.(cl) <- !offset;
    offset := !offset + desc.nb_svs);
  starts

let parse_dual_coefs parser name_and_nb_sv_of_class =
  (* Parses the dual coefficients: a matrix with one row less than the number
     of classes, each row holding one float per support vector. *)
  let nb_rows = Array.length name_and_nb_sv_of_class - 1 in
  let row_length = nb_svs name_and_nb_sv_of_class in
  read_2_dim_float parser "dual coefficients" nb_rows (fun _ -> row_length)

let parse_intercept parser nb_classes =
  (* Parses the intercepts: one float per pair of distinct classes, i.e.
     [nb_classes * (nb_classes - 1) / 2] values. Errors are reported with the
     context "intercept" (it used to be reported as "dual coefficients", which
     pointed to the wrong section of the file). *)
  let nb_pairs = nb_classes * (nb_classes - 1) / 2 in
  read_float_array parser "intercept" nb_pairs

let parse parser =
  (* Parses a whole OVO description from [parser]. The sections are read in
     this order: header, kernel, class descriptions, dual coefficients,
     support vectors, intercepts; nothing may follow the intercepts. The
     description of the input language is in the mli file. *)
  parse_header parser >>= fun (nb_ins, nb_classes) ->
  parse_kernel_type parser nb_ins >>= fun k ->
  parse_classes_description parser nb_classes >>= fun name_and_nb_sv_of_class ->
  parse_dual_coefs parser name_and_nb_sv_of_class >>= fun dual_coefs ->
  parse_support_vectors parser nb_ins name_and_nb_sv_of_class
  >>= fun support_vectors ->
  parse_intercept parser nb_classes >>= fun intercept ->
  check_eof parser >>| fun () ->
  {
    nb_ins;
    nb_classes;
    name_and_nb_sv_of_class;
    start_of_sv_of_class = start_of_svs name_and_nb_sv_of_class;
    dual_coefs;
    support_vectors;
    intercept;
    k;
  }

(* [parse inc] parses an OVO description from the channel [inc]. A parsing
   error is raised as [Failure]; the returned result is therefore always
   [Ok]. *)
let parse inc =
  match create_parser inc |> parse with
  | Ok _ as parsed -> parsed
  | Error msg -> failwith msg

(* [parse filename] parses the OVO description stored in the file [filename].
   The file is closed whether parsing succeeds or raises. *)
let parse filename =
  let ic = Stdlib.open_in filename in
  let close () = Stdlib.close_in ic in
  let run () = parse ic in
  Stdlib.Fun.protect ~finally:close run

(* ACCESSES *)

(* [nb_svs ovo] is the total number of support vectors of [ovo], i.e. the sum
   of the numbers of SVs found in the description of each class. *)
let nb_svs ovo =
  Array.sum
    (module Int)
    ovo.name_and_nb_sv_of_class
    ~f:(fun desc -> desc.nb_svs)

let svs ovo =
  (* [svs ovo] is the list of all the support vectors of [ovo]: those of class
     [0] first (in their stored order), then those of class [1], and so on up
     to class [ovo.nb_classes - 1]. *)
  List.init ovo.nb_classes ~f:(fun cl ->
    Array.to_list ovo.support_vectors.(cl))
  |> List.concat

(* [node_of_constant data] is a node built with [Nir.Node.create] from the
   constant descriptor holding [data]. *)
let node_of_constant data =
  let descr = Nir.Node.Constant { data } in
  Nir.Node.create descr

(* Transformation OVO -> Nier *)

let build_kernel ovo input_node (* shape (n) *) =
  (* [build_kernel ovo input_node] is the node applying the kernel of [ovo] to
     [input_node] against all the support vectors of [ovo] at once: the input
     is multiplied (Matmul) by the constant matrix of the SVs, then, for a
     polynomial kernel, the result is mapped through
     [(product * gamma + coef) ^ degree].

     In the shape annotations below, [n] presumably stands for [ovo.nb_ins]
     and [m] for the total number of SVs -- TODO confirm against
     [Nir.Gentensor.of_float_matrix ~trans:true]. *)
  let data =
    Nir.Gentensor.of_float_matrix ~trans:true (svs ovo |> Array.of_list)
  in
  (* [data]: all the SVs of [ovo], in the order of [svs ovo]; shape (n,m) *)
  let input2 = node_of_constant data in
  (* (n,m) *)
  let product =
    Nir.Node.create @@ Nir.Node.Matmul { input1 = input_node; input2 }
  in
  match ovo.k with
  | Linear -> product (* shape (m,) *)
  | Poly { gamma; degree; coef } ->
    (* [constant_node shape v] is a constant node of shape [shape] built from
       the single float [v] *)
    let constant_node shape v =
      node_of_constant @@ Nir.Gentensor.create_const_float shape v
    in
    (* the three kernel parameters are turned into constants of the same shape
       as [product] *)
    let shape = Nir.Node.compute_shape product in
    let constant_gamma = constant_node shape gamma in
    let constant_degree = constant_node shape degree in
    let constant_coef = constant_node shape coef in
    (* [*] and [+] below are the operators of [Nir.Node], brought in scope by
       the local open *)
    Nir.Node.(
      let input1 = (product * constant_gamma) + constant_coef in
      let input2 = constant_degree in
      Nir.Node.create @@ Nir.Node.Pow { input1; input2 })

let pairs_of_classes nb_classes =
  (* [pairs_of_classes nb_classes] is the sequence of all the pairs [(j, l)]
     such that [0 <= j < l < nb_classes], in lexicographic order: [(0,1)],
     [(0,2)], ..., [(0,nb_classes-1)], [(1,2)], ... The sequence is empty when
     [nb_classes < 2].

     Bounds are tested with an integer comparison (rather than physical
     equality), and with [>=] so that the enumeration also terminates for a
     negative [nb_classes]. *)
  let rec aux j l () =
    if j >= nb_classes
    then Seq.Nil
    else if l >= nb_classes
    then
      (* no second component left for [j]: move on to [j + 1] *)
      aux (j + 1) (j + 2) ()
    else Seq.Cons ((j, l), aux j (l + 1))
  in
  aux 0 1

let to_nn ovo =
  (* [to_nn ovo] is the graph (built by [Nir.Ngraph.create]) computing the
     scores of the classes of [ovo] in three steps: (1) the kernel of the
     input against every support vector, (2) the sign of the decision value of
     each pair of classes, (3) for each class, the sum of the signs of the
     pairs it belongs to, counted positively when the class is the first
     component of the pair and negatively when it is the second. *)
  (* CONSTANTS *)
  let c = ovo.nb_classes in
  (* [p]: number of pairs of distinct classes *)
  let p = c * (c - 1) / 2 in
  (* [s]: total number of support vectors *)
  let s = nb_svs ovo in
  (* HELPERS *)
  (* [n1 ** n2] is the Matmul node of [n1] and [n2]. NOTE(review): [**] is
     also used below inside a local open of [Nir.Node]; if [Nir.Node] defines
     its own [**], that definition is the one used there -- confirm. *)
  let ( ** ) n1 n2 =
    Nir.Node.create @@ Nir.Node.Matmul { input1 = n1; input2 = n2 }
  in
  let sign node = Nir.Node.create @@ Nir.Node.Sign { input = node } in
  (* reshapes [input] to its own shape prefixed with a dimension of size 1 *)
  let add_one_dimension input =
    Nir.Node.reshape
      (Nir.Node.compute_shape input
      |> Nir.Shape.to_list |> List.cons 1 |> Nir.Shape.of_list)
      input
  in

  (* STEP 1: kernel *)
  let input_node =
    (* shape (n,) *)
    Nir.Node.create
      (Nir.Node.Input { shape = Nir.Shape.of_array [| ovo.nb_ins |] })
  in
  let kernel = build_kernel ovo input_node |> add_one_dimension in

  (* Shape (1,s) *)

  (* STEP 2: binary classification of each pair (cl1,cl2) *)
  let dual_coef cl cl' i =
    (* the dual coefficient of the [i]th SV of [cl] associated with [cl']. [i]
       should be between [0] and [ovo.name_and_nb_sv_of_class.(cl).nb_svs -
       1]. The row [cl'] is shifted by one when [cl' > cl] because no row is
       stored for [cl' = cl]. *)
    ovo.dual_coefs.(if cl' > cl then cl' - 1 else cl').(global ovo cl i)
  in
  let dual_coefs =
    (* matrix of dual coefs, with [s] rows (one per SV, at its global index)
       and [p] columns (one per pair of classes, at its [pair_index]). In the
       column of [(cl1,cl2)], only the rows of the SVs of [cl1] and [cl2] are
       filled; the others stay at [0.0]. *)
    let mat = Array.make_matrix ~dimx:s ~dimy:p 0.0 in
    let () =
      pairs_of_classes c
      |> Seq.iter (fun (cl1, cl2) ->
           let pair_idx = pair_index ovo cl1 cl2 in
           (* SVs of [cl1]: coefficient associated with [cl2] *)
           let () =
             for i = 0 to ovo.name_and_nb_sv_of_class.(cl1).nb_svs - 1 do
               let idx = global ovo cl1 i in
               mat.(idx).(pair_idx) <- dual_coef cl1 cl2 i
             done
           in
           (* SVs of [cl2]: coefficient associated with [cl1] *)
           let () =
             for i = 0 to ovo.name_and_nb_sv_of_class.(cl2).nb_svs - 1 do
               let idx = global ovo cl2 i in
               mat.(idx).(pair_idx) <- dual_coef cl2 cl1 i
             done
           in
           ())
    in
    node_of_constant @@ Nir.Gentensor.of_float_matrix ~trans:false mat
  in
  let intercept =
    node_of_constant @@ Nir.Gentensor.of_float_array ovo.intercept
  in
  (* decision value of each pair: kernel times dual coefs, plus intercept
     ("row" presumably stands for "raw") *)
  let row_ovo_scores = Nir.Node.(intercept + (kernel ** dual_coefs)) in
  let signed_ovo_scores = sign row_ovo_scores in

  (* STEP 3: Adding up wins of each class = outcome of Nir *)
  let score_filter =
    (* indicates which outputs of sign are used to compute the score of each
       class: a matrix with [p] rows (one per pair) and [c] columns (one per
       class) where the row of [(cl1,cl2)] holds [1.] in column [cl1], [-1.]
       in column [cl2] and [0.] elsewhere *)
    let mat = Array.make_matrix ~dimx:p ~dimy:c 0.0 in
    let () =
      pairs_of_classes c
      |> Seq.iter (fun (cl1, cl2) ->
           let pair_idx = pair_index ovo cl1 cl2 in
           let () = mat.(pair_idx).(cl1) <- 1. in
           let () = mat.(pair_idx).(cl2) <- -1. in
           ())
    in
    node_of_constant @@ Nir.Gentensor.of_float_matrix ~trans:false mat
  in
  let scores = signed_ovo_scores ** score_filter in
  Nir.Ngraph.create scores

(* eof *)
OCaml

Innovation. Community. Security.