package caisar

  1. Overview
  2. Docs

Source file predict.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2024                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

(** [compute_tree t input] walks the tree [t] from node [0] and returns the
    index of the node at which the walk stops.

    A node whose entry in [t.left_children] is [-1] ends the walk and its
    index is returned. Otherwise the value of the feature designated by
    [t.split_indices] is fetched from [input]: the walk continues with the
    left child when the value is absent or strictly smaller than the node's
    entry in [t.split_conditions], and with the right child otherwise. *)
let compute_tree (t : Parser.tree) (input : Input.t) : int =
  let rec descend idx =
    let left = t.left_children.(idx) in
    (* Anything below -1 is not a valid child index. *)
    assert (-1 <= left);
    if left = -1
    then idx
    else
      let go_left =
        match Input.get input t.split_indices.(idx) with
        | None ->
          (* TODO: check if missing can be on the right *)
          true
        | Some value -> value < t.split_conditions.(idx)
      in
      descend (if go_left then left else t.right_children.(idx))
  in
  descend 0

(** [sigmoid x] is the logistic function [1 / (1 + exp (-x))], computed with
    a small additive constant in the denominator and with the exponent
    capped at [88.7]. *)
let sigmoid x =
  (* The implementation this was ported from works in single precision;
     double-precision floats are used here. *)
  let epsilon = 1e-16 in
  (* NOTE(review): the 88.7 cap presumably keeps [exp] finite in single
     precision in the original code; it is kept as is. *)
  let exponent = Float.min (Float.neg x) 88.7 in
  1.0 /. (exp exponent +. 1.0 +. epsilon)

(** [compute_trees t gb input] evaluates the tree ensemble [gb] of the model
    [t] on [input] and returns the resulting prediction.

    The raw margin is an initial margin derived from the model's
    [base_score], plus, for every tree in [gb.trees], the entry of
    [split_conditions] at the node selected by [compute_tree]. The margin is
    then turned into a prediction according to the learner objective:
    unchanged for the regression objectives, passed through [sigmoid] for
    [Binary_logistic].

    @raise Invalid_argument
      if the objective is [Reg_pseudohubererror] (not implemented), or if
      the objective is [Binary_logistic] and [base_score] is not strictly
      between 0 and 1.
    @raise Failure if [base_score] is not a valid float literal. *)
let compute_trees (t : Parser.t) (gb : Parser.gbtree) input : float =
  let initial_margin =
    let base_score = Float.of_string t.learner.learner_model_param.base_score in
    match t.learner.objective with
    | Parser.Reg_squarederror _ -> base_score
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Reg_squaredlogerror _ ->
      (* NOTE(review): assumed to be used untransformed, as in XGBoost. *)
      base_score
    | Parser.Reg_linear _ ->
      (* NOTE(review): assumed to be used untransformed, as in XGBoost. *)
      base_score
    | Parser.Binary_logistic _ ->
      (* For the logistic objective XGBoost stores [base_score] as a
         probability and converts it to a margin with the logit function
         (ProbToMargin), rejecting values outside (0,1). The default 0.5
         yields a zero margin. *)
      if base_score > 0. && base_score < 1.
      then -.log ((1.0 /. base_score) -. 1.0)
      else invalid_arg "base_score must be in (0,1) for logistic loss"
  in

  let sum =
    Array.fold_left
      (fun acc (tree : Parser.tree) ->
        let node = compute_tree tree input in
        (* The value of the reached node is read from [split_conditions]. *)
        let v = tree.split_conditions.(node) in
        (* Format.eprintf "node:%i -> %f@." node v; *)
        acc +. v)
      initial_margin gb.trees
  in
  (* From regression_loss.h PredTransform *)
  let pred =
    match t.learner.objective with
    | Parser.Reg_squarederror _ -> sum
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Reg_squaredlogerror _ -> sum
    | Parser.Reg_linear _ -> sum
    | Parser.Binary_logistic _ -> sigmoid sum
  in
  (* Format.eprintf "%f -> %f@." sum pred; *)
  pred

(** [predict t input] is the prediction of the model [t] on [input].

    Only the [Gbtree] gradient booster is handled, by delegating to
    [compute_trees]; the [Gblinear] and [Dart] boosters trigger an assertion
    failure. *)
let predict (t : Parser.t) input =
  let booster = t.learner.gradient_booster in
  match booster with
  | Parser.Gblinear _ | Parser.Dart _ -> assert false
  | Parser.Gbtree trees -> compute_trees t trees input
OCaml

Innovation. Community. Security.