package caisar
A platform for characterizing the safety and robustness of artificial intelligence based software
Install
Dune Dependency
Authors
Maintainers
Sources
caisar-2.0.tbz
sha256=3d24d2940eed0921acba158a8970687743c401c6a99d0aac8ed6dcfedca1429c
sha512=0b4484c0e080b8ba22722fe9d5665f9015ebf1648ac89c566a978dd54e3e061acb63edd92e078eed310e26f3e8ad2c48f3682a24af2acb1f0633da12f7966a38
doc/src/caisar.xgboost/tree.ml.html
Source file tree.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2024                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

(** A decision tree: an internal [Split] node compares the input feature at
    index [split_indice] against [split_condition]; a [Leaf] carries the
    value contributed by the tree. [missing] records the branch taken when
    the feature is absent; only [`Left] is representable for now. *)
type tree =
  | Split of {
      split_indice : int;
      split_condition : float;
      left : tree;
      right : tree;
      missing : [ `Left ];
    }
  | Leaf of { leaf_value : float }

(** Transformation applied to the summed score. *)
type op =
  | Identity
  | Sigmoid

(** A tree ensemble; the value is [op(base_score + sum(tree))]. *)
type t = {
  base_score : float;
  trees : tree array;
  after_sum : op;
}

(** [predict t input] evaluates every tree of [t] on [input], adds the leaf
    values to [t.base_score] and applies [t.after_sum] to the total.

    At a split, a feature for which [Input.get] returns [None] follows the
    left branch; otherwise the left branch is taken when the feature value is
    strictly below [split_condition], and the right branch in every other
    case. *)
let predict t input =
  let rec aux input = function
    | Split s -> (
      match Input.get input s.split_indice with
      | None ->
        aux input s.left (* TODO: check if missing can be on the right *)
      | Some v when v < s.split_condition -> aux input s.left
      | _ -> aux input s.right)
    | Leaf l -> l.leaf_value
  in
  let sum =
    Array.fold_left (fun acc t -> acc +. aux input t) t.base_score t.trees
  in
  match t.after_sum with Identity -> sum | Sigmoid -> Predict.sigmoid sum

(** [convert_tree t] rebuilds the flat, array-encoded tree [t] as a recursive
    {!tree}, starting from node [0].

    A node whose [left_children] entry is [-1] becomes a [Leaf]; its value is
    read from [split_conditions] at the same index. Any other node becomes a
    [Split] whose children are converted recursively.

    @raise Assert_failure if a [left_children] entry is below [-1]. *)
let convert_tree (t : Parser.tree) : tree =
  let rec aux node =
    assert (-1 <= t.left_children.(node));
    if t.left_children.(node) = -1
    then Leaf { leaf_value = t.split_conditions.(node) }
    else
      Split
        {
          split_indice = t.split_indices.(node);
          split_condition = t.split_conditions.(node);
          left = aux t.left_children.(node);
          right = aux t.right_children.(node);
          missing = `Left;
        }
  in
  aux 0

(** [convert_trees t gb] builds the ensemble for the trees of [gb], using the
    objective and the [base_score] string found in the learner of [t].

    The base score is moved to margin space and the output transformation is
    chosen per objective (XGBoost regression_loss.h, ProbToMargin and
    PredTransform): the regression objectives keep the base score as is and
    use [Identity]; [Binary_logistic] uses the logit of the base score and
    [Sigmoid].

    @raise Invalid_argument
      for [Reg_pseudohubererror] (unimplemented), or for [Binary_logistic]
      when the base score is not strictly between 0 and 1.
    @raise Failure if the base score string is not a valid float. *)
let convert_trees (t : Parser.t) (gb : Parser.gbtree) : t =
  let base_score, after_sum =
    let base_score =
      Float.of_string t.learner.learner_model_param.base_score
    in
    (* The model stores the base score as a probability for the logistic
       objective, whereas the trees are summed in margin space. The logit of
       the default 0.5 is exactly 0., so default models are unaffected. *)
    let logit p =
      if p > 0. && p < 1.
      then Float.log (p /. (1. -. p))
      else invalid_arg "base_score must be in (0,1) for logistic loss"
    in
    match t.learner.objective with
    | Parser.Reg_squarederror _ -> (base_score, Identity)
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Reg_squaredlogerror _ -> (base_score, Identity) (* ? *)
    | Parser.Reg_linear _ -> (base_score, Identity) (* ? *)
    | Parser.Binary_logistic _ -> (logit base_score, Sigmoid)
  in
  let trees = Array.map convert_tree gb.trees in
  { base_score; trees; after_sum }

(** [convert t] converts a parsed model into a tree ensemble. Only the
    [Gbtree] booster is handled.

    @raise Assert_failure for the [Gblinear] and [Dart] boosters. *)
let convert (t : Parser.t) =
  match t.learner.gradient_booster with
  | Parser.Gbtree gbtree -> convert_trees t gbtree
  | Parser.Gblinear _ -> assert false
  | Parser.Dart _ -> assert false
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>