package caisar
A platform for characterizing the safety and robustness of artificial intelligence based software
Install
Dune Dependency
Authors
Maintainers
Sources
caisar-2.1.tbz
sha256=1b25c8668d428bcfc83c95147b6e45ff0a3bfa05ecd11369d12e963e29819e2e
sha512=edc7d7c0e96802811de3cb1caa3d14cc3d867ee7310748e8188eca9246a362549545c7816c8037511931dc4b7770b5ccc11b0d03abe8843b7c4db7880bf8e1fd
doc/src/caisar.xgboost/tree.ml.html
Source file tree.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2024                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

(** A decision tree of a gradient-boosted ensemble. *)
type tree =
  | Split of {
      split_indice : int;  (* index of the feature tested at this node *)
      split_condition : float;
          (* the left child is taken when the feature is strictly below it *)
      left : tree;
      right : tree;
      missing : [ `Left ];  (* branch followed when the feature is absent *)
    }
  | Leaf of { leaf_value : float }

(** Transformation applied to the summed margin to obtain the prediction. *)
type op =
  | Identity
  | Sigmoid

(** A tree ensemble. The predicted value is
    [op(base_score + sum(tree))]. *)
type t = {
  base_score : float;  (* bias, added to the tree outputs before [after_sum] *)
  trees : tree array;
  after_sum : op;
}

(** [predict t input] evaluates the ensemble [t] on [input]: each tree is
    walked from its root to a leaf, the leaf values are added to
    [t.base_score], and [t.after_sum] is applied to the total.

    A feature for which [Input.get] returns [None] follows the node's
    [missing] branch. A present value [v] goes left iff
    [v < split_condition], so a [Some nan] value goes right.
    NOTE(review): XGBoost treats NaN as missing — confirm whether [Input]
    can yield NaN. *)
let predict t input =
  let rec eval_tree = function
    | Leaf { leaf_value } -> leaf_value
    | Split s -> (
      match Input.get input s.split_indice with
      | None -> (
        (* TODO: check if missing can be on the right *)
        match s.missing with `Left -> eval_tree s.left)
      | Some v when v < s.split_condition -> eval_tree s.left
      | Some _ -> eval_tree s.right)
  in
  let sum =
    Array.fold_left
      (fun acc tree -> acc +. eval_tree tree)
      t.base_score t.trees
  in
  match t.after_sum with
  | Identity -> sum
  | Sigmoid -> Predict.sigmoid sum

(** [convert_tree t] converts a parsed tree into a {!tree}.

    The parsed tree is a set of arrays indexed by node id, with the root at
    index [0]. A node whose left child is [-1] is a leaf, and its value is
    read from [split_conditions]. Missing features always go to the left
    child. *)
let convert_tree (t : Parser.tree) : tree =
  let rec aux node =
    let left_child = t.left_children.(node) in
    assert (-1 <= left_child);
    if left_child = -1
    then Leaf { leaf_value = t.split_conditions.(node) }
    else
      Split
        {
          split_indice = t.split_indices.(node);
          split_condition = t.split_conditions.(node);
          left = aux left_child;
          right = aux t.right_children.(node);
          missing = `Left;
        }
  in
  aux 0

(** [convert_trees t gb] builds the ensemble of the model [t], whose booster
    is the tree booster [gb].

    The [base_score] stored in the model is converted into the space in
    which the tree outputs are summed (XGBoost's [ProbToMargin]). The output
    transformation mirrors [PredTransform] in XGBoost's [regression_loss.h].

    @raise Invalid_argument for the pseudo-Huber objective (unimplemented),
    or for a logistic model whose [base_score] is not in the open interval
    (0, 1). *)
let convert_trees (t : Parser.t) (gb : Parser.gbtree) : t =
  let base_score =
    let base_score = Float.of_string t.learner.learner_model_param.base_score in
    match t.learner.objective with
    | Parser.Reg_squarederror _ -> base_score
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Reg_squaredlogerror _ ->
      base_score (* ? presumably identity margin — TODO confirm *)
    | Parser.Reg_linear _ ->
      base_score (* ? presumably identity margin — TODO confirm *)
    | Parser.Binary_logistic _ ->
      (* The model stores base_score as a probability, but the trees are
         summed in logit space before the sigmoid. Convert it with the
         inverse sigmoid, as XGBoost's LogisticRegression::ProbToMargin does.
         Hard-coding 0. is only correct for the default base_score of 0.5. *)
      if not (0. < base_score && base_score < 1.)
      then invalid_arg "base_score must be in (0,1) for logistic loss"
      else -.Float.log ((1. /. base_score) -. 1.)
  in
  let trees = Array.map convert_tree gb.trees in
  (* From regression_loss.h PredTransform *)
  let after_sum =
    match t.learner.objective with
    | Parser.Reg_squarederror _ -> Identity
    | Parser.Reg_pseudohubererror _ -> invalid_arg "unimplemented"
    | Parser.Reg_squaredlogerror _ -> Identity
    | Parser.Reg_linear _ -> Identity
    | Parser.Binary_logistic _ -> Sigmoid
  in
  { base_score; trees; after_sum }

(** [convert t] converts a parsed XGBoost model into a tree ensemble.

    @raise Invalid_argument if the model uses the [gblinear] or [dart]
    booster, which are not supported, or for the cases listed in
    {!convert_trees}. *)
let convert (t : Parser.t) =
  match t.learner.gradient_booster with
  | Parser.Gbtree gbtree -> convert_trees t gbtree
  | Parser.Gblinear _ -> invalid_arg "unimplemented: gblinear booster"
  | Parser.Dart _ -> invalid_arg "unimplemented: dart booster"
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>