package caisar

  1. Overview
  2. Docs

Source file tree.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2024                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

(** A single decision tree, built by [convert_tree] from a [Parser.tree] and
    walked by [predict]. *)
type tree =
  | Split of {
      split_indice : int;
          (** Index of the input feature tested at this node. *)
      split_condition : float;
          (** Threshold compared with the feature value: values strictly
              below it go to [left]. *)
      left : tree;
      right : tree;
      missing : [ `Left ];
          (** Direction for an absent feature value. Only [`Left] is
              representable. NOTE(review): XGBoost models can also send
              missing values to the right. *)
    }
  | Leaf of { leaf_value : float }
      (** Terminal node; [leaf_value] is added to the ensemble's sum when
          this leaf is reached. *)

(** Transformation applied by [predict] to the sum of [base_score] and the
    tree outputs. *)
type op =
  | Identity  (** The sum is returned unchanged. *)
  | Sigmoid  (** [Predict.sigmoid] is applied to the sum. *)

type t = {
  base_score : float;
      (** Initial value of the sum, before any tree output is added. *)
  trees : tree array;  (** Trees whose reached leaf values are summed. *)
  after_sum : op;  (** Transformation applied to the final sum. *)
}
(** A tree ensemble. Its prediction on an input is
    [after_sum (base_score + sum of the leaf values reached in each tree)]. *)

(** [predict t input] evaluates the ensemble [t] on [input].

    Each tree is walked from its root. At a [Split] node, the feature
    [split_indice] of [input] is compared with [split_condition]:
    - values strictly below the threshold go to [left];
    - other values go to [right].

    An absent feature ([None] from [Input.get]) or a NaN value follows the
    node's [missing] direction. XGBoost treats NaN as a missing value, so
    NaN must not be compared with the threshold (a comparison with NaN is
    always false and would silently pick [right]).

    The reached leaf values are added to [t.base_score], and [t.after_sum]
    is applied to the total. *)
let predict t input =
  let rec eval_tree = function
    | Leaf l -> l.leaf_value
    | Split s -> (
      (* Branch taken for an absent or NaN feature value. XGBoost's
         [default_left] flag can also send missing values to the right, but
         the [missing] field cannot represent that yet. *)
      let on_missing () = match s.missing with `Left -> eval_tree s.left in
      match Input.get input s.split_indice with
      | None -> on_missing ()
      | Some v when Float.is_nan v -> on_missing ()
      | Some v when v < s.split_condition -> eval_tree s.left
      | Some _ -> eval_tree s.right)
  in
  let sum =
    Array.fold_left
      (fun acc tree -> acc +. eval_tree tree)
      t.base_score t.trees
  in
  match t.after_sum with Identity -> sum | Sigmoid -> Predict.sigmoid sum

(** [convert_tree t] rebuilds, starting from node [0], the tree stored in the
    parallel arrays of [t].

    A node whose [left_children] entry is [-1] is a leaf. Its value is read
    from [split_conditions]; XGBoost's JSON format stores leaf values in that
    array.

    Any other node is a split:
    - its feature index and threshold come from [split_indices] and
      [split_conditions];
    - its children come from [left_children] and [right_children];
    - missing values are always sent to the left child.

    @raise Assert_failure if a [left_children] entry is below [-1].
    @raise Invalid_argument if a node index is out of the arrays' bounds. *)
let convert_tree (t : Parser.tree) : tree =
  let rec build node =
    let left_child = t.left_children.(node) in
    assert (left_child >= -1);
    match left_child with
    | -1 -> Leaf { leaf_value = t.split_conditions.(node) }
    | _ ->
      Split
        {
          split_indice = t.split_indices.(node);
          split_condition = t.split_conditions.(node);
          left = build left_child;
          right = build t.right_children.(node);
          missing = `Left;
        }
  in
  build 0

(** [convert_trees t gb] converts the tree booster [gb] of the parsed model
    [t] into an ensemble.

    The model stores [base_score] in prediction space. It is mapped to margin
    space the way XGBoost's [ProbToMargin] does for the model's objective:
    - identity for the squared-error, squared-log-error and linear
      objectives;
    - [-log (1/p - 1)] (the logit) for [binary:logistic].

    The transformation applied after summing follows the objective's
    [PredTransform] (see XGBoost's regression_loss.h).

    @raise Invalid_argument for the pseudo-Huber objective, which is not
    implemented, or for a logistic model whose [base_score] is not in
    (0, 1).
    @raise Failure if [base_score] is not a valid float string. *)
let convert_trees (t : Parser.t) (gb : Parser.gbtree) : t =
  let base_score = Float.of_string t.learner.learner_model_param.base_score in
  let base_score, after_sum =
    match t.learner.objective with
    | Parser.Reg_squarederror _ | Parser.Reg_squaredlogerror _
    | Parser.Reg_linear _ ->
      (* ProbToMargin and PredTransform are both the identity here. *)
      (base_score, Identity)
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Binary_logistic _ ->
      (* XGBoost rejects such values too: the logit is undefined outside
         (0, 1). *)
      if not (0. < base_score && base_score < 1.)
      then
        invalid_arg
          "convert_trees: base_score must be in (0, 1) for binary:logistic";
      (* LogisticRegression::ProbToMargin; yields 0 for the default 0.5. *)
      (-.Float.log ((1. /. base_score) -. 1.), Sigmoid)
  in
  let trees = Array.map convert_tree gb.trees in
  { base_score; trees; after_sum }

(** [convert t] converts a parsed model into a tree ensemble. Only the
    [gbtree] booster is supported.

    @raise Invalid_argument for the [gblinear] and [dart] boosters, which
    are not implemented. Exceptions raised by [convert_trees] are
    propagated. *)
let convert (t : Parser.t) =
  match t.learner.gradient_booster with
  | Parser.Gbtree gbtree -> convert_trees t gbtree
  (* These are valid but unsupported models, not programming errors, so
     report them with [invalid_arg] rather than [assert false]. *)
  | Parser.Gblinear _ -> invalid_arg "unimplemented: gblinear booster"
  | Parser.Dart _ -> invalid_arg "unimplemented: dart booster"
OCaml

Innovation. Community. Security.