Source file fin.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
module Make (Reals : Basic_intf.Reals) = struct
(* Type of the weights manipulated in this functor: the carrier of [Reals]. *)
type r = Reals.t
(* [Stats_intf.fin_fun] specialised to values in [r]. *)
type 'a finfn = ('a, r) Stats_intf.fin_fun
(* [Stats_intf.fin_prb] specialised to values in [r]. *)
type 'a prb = ('a, r) Stats_intf.fin_prb
(* [Stats_intf.fin_mes] specialised to values in [r]. *)
type 'a mes = ('a, r) Stats_intf.fin_mes
(* Alias for the state of the standard library random generator. *)
type state = Random.State.t
(* [of_fun len f] is the finite function defined on the indices
   [0 .. len - 1]. Its iterator visits these indices in increasing order.
   Evaluating it at [i] gives [f i] when [i] is inside that range and
   [Reals.zero] otherwise.
   Raises [Invalid_argument] when [len] is negative. *)
let of_fun : int -> (int -> r) -> int finfn =
 fun len f ->
  if len < 0 then invalid_arg "of_fun" ;
  let iterator k =
    (* Tail-recursive walk over 0, 1, ..., len - 1. *)
    let rec visit i =
      if i < len then (
        k i ;
        visit (i + 1))
    in
    visit 0
  in
  let eval i = if i >= 0 && i < len then f i else Reals.zero in
  Vec.Vec (iterator, eval)
(* [of_array array] views [array] as a finite function on its indices.
   The iterator visits the indices [0 .. Array.length array - 1] in
   increasing order. Evaluation returns the cell at a valid index and
   [Reals.zero] at any other index.
   The array is captured, not copied: evaluation reads it at call time. *)
let of_array : r array -> int finfn =
 fun array ->
  let iterator k = Array.iteri (fun i _ -> k i) array in
  let eval i =
    (* The bounds test makes the unchecked read safe. *)
    if i >= 0 && i < Array.length array then Array.unsafe_get array i
    else Reals.zero
  in
  Vec.Vec (iterator, eval)
(* [of_assoc (module H) array] builds a finite function from an array of
   key/value pairs. The pairs are loaded into a hash table created by
   [H.of_seq]. The iterator enumerates the keys of that table in the order
   of [H.iter]. Evaluation looks the key up and returns [Reals.zero] when
   it is not bound. *)
let of_assoc :
    type k.
    (module Hashtbl.S with type key = k) ->
    (k * r) array ->
    (k, r) Stats_intf.fin_fun =
 fun (type k) (module H : Hashtbl.S with type key = k) array ->
  let table = array |> Array.to_seq |> H.of_seq in
  let iterator k = H.iter (fun key _ -> k key) table in
  let eval key =
    match H.find table key with
    | weight -> weight
    | exception Not_found -> Reals.zero
  in
  Vec.Vec (iterator, eval)
(* [measure finfn] wraps [finfn] as a measure. The total mass is the sum,
   computed with [Reals.add], of the values of [finfn] over the points
   enumerated by its iterator.
   Raises [Invalid_argument] as soon as a visited point has a weight
   strictly below [Reals.zero]. *)
let measure : 'a finfn -> 'a mes =
 fun (type a) (finfn : a finfn) ->
  let (Vec.Vec (iter, weight_of)) = finfn in
  let total = ref Reals.zero in
  let accumulate x =
    let w = weight_of x in
    total := Reals.add !total w ;
    if Reals.(w < zero) then invalid_arg "measure: negative weight"
  in
  iter accumulate ;
  Stats_intf.M { total_mass = !total; fn = finfn }
(* [probability finfn] wraps [finfn] as a probability after two checks,
   performed in this order:
   - the values of [finfn] over its iterator must sum (with [Reals.add])
     to exactly [Reals.one] according to the equality of [Reals];
   - no visited point may have a value strictly below [Reals.zero].
   Raises [Invalid_argument] when either check fails. *)
let probability : 'a finfn -> 'a prb =
 fun (type a) (finfn : a finfn) ->
  let (Vec.Vec (iter, mass_of)) = finfn in
  let running = ref Reals.zero in
  iter (fun x -> running := Reals.add !running (mass_of x)) ;
  let total_mass = !running in
  if Reals.(total_mass = one) then ()
  else
    Format.kasprintf
      invalid_arg
      "probability: mass do not sum up to 1 (%a)"
      Reals.pp
      total_mass ;
  (* Second pass: reject negative masses. *)
  let check_sign x =
    if Reals.(mass_of x < zero) then invalid_arg "probability: negative mass"
  in
  iter check_sign ;
  Stats_intf.P { fn = finfn }
(* [as_measure prb] reuses the finite function of [prb] as a measure whose
   recorded total mass is [Reals.one]. *)
let as_measure prb =
  let (Stats_intf.P { fn }) = prb in
  Stats_intf.M { fn; total_mass = Reals.one }
(* [total_mass mes] returns the total mass recorded in [mes]. It reads the
   stored field and does not recompute the sum. *)
let total_mass (type t) (M { total_mass = recorded; fn = _ } : t mes) : r =
  recorded
(* [normalize mes] rescales [mes] into a probability: every weight is
   multiplied by the inverse of the recorded total mass. The iterator of
   the underlying finite function is kept as is.
   Raises [Invalid_argument] when the total mass equals [Reals.zero]. *)
let normalize (type t) (M { total_mass; fn } : t mes) :
    (t, r) Stats_intf.fin_prb =
  match fn with
  | Vec.Vec (shape, weight_of) ->
      if Reals.equal total_mass Reals.zero then
        invalid_arg "normalize: null measure" ;
      (* The inverse is computed once and shared by all evaluations. *)
      let inv_w = Reals.(one / total_mass) in
      let rescaled x = Reals.mul (weight_of x) inv_w in
      Stats_intf.P { fn = Vec.Vec (shape, rescaled) }
(* [sample mes rng_state] draws one point of [mes] with probability
   proportional to its weight.

   A threshold [r] is computed as the total mass of [mes] times
   [Reals.lebesgue rng_state] (presumably a draw in the unit interval;
   confirm against the implementation of [Reals]). The points are then
   visited in iterator order and the first one whose cumulative weight
   reaches [r] is returned.

   The recorded total mass may be slightly larger than the cumulative sum
   recomputed here (rounding), in which case no point reaches [r]. The
   last visited point with a strictly positive weight is then returned,
   instead of failing on an assertion.

   Raises [Invalid_argument] when the iteration ends without reaching [r]
   and no point with a strictly positive weight was seen, for instance on
   an empty measure. *)
let sample (type t) (M { total_mass; fn } : t mes) rng_state =
  let exception Sampled of t in
  let (Vec.Vec (iter, f)) = fn in
  let r = Reals.(total_mass * lebesgue rng_state) in
  let cumu = ref Reals.zero in
  (* Fallback candidate: last point seen so far carrying positive weight. *)
  let last = ref None in
  match
    iter (fun i ->
        let w = f i in
        let c = Reals.add !cumu w in
        (* The local exception is the only way to leave the iterator early. *)
        if Reals.(r <= c) then raise (Sampled i) ;
        if Reals.(zero < w) then last := Some i ;
        cumu := c)
  with
  | exception Sampled x -> x
  | () -> (
      match !last with
      | Some x -> x
      | None -> invalid_arg "sample: empty or null measure")
(* [counts_of_empirical (module H) p] counts the occurrences of each
   element of the array [p]. The result is a measure whose weight at an
   element is its number of occurrences (zero when absent) and whose total
   mass is the length of [p]. Points are enumerated in the order of
   [H.iter] on the table of counts. *)
let counts_of_empirical (type t) (module H : Hashtbl.S with type key = t)
    (p : t array) : t mes =
  let size = Array.length p in
  let table = H.create size in
  let bump elt =
    let seen = match H.find_opt table elt with None -> 0 | Some c -> c in
    H.replace table elt (seen + 1)
  in
  Array.iter bump p ;
  let iter k = H.iter (fun key _ -> k key) table in
  let count_of key =
    match Reals.of_int (H.find table key) with
    | count -> count
    | exception Not_found -> Reals.zero
  in
  Stats_intf.M { total_mass = Reals.of_int size; fn = Vec.Vec (iter, count_of) }
(* [uniform arr] is the probability whose iterator enumerates the cells of
   [arr] from first to last and whose evaluation function returns
   [one / length arr] for every argument, whether or not it occurs in
   [arr].
   Raises [Failure] when [arr] is empty. *)
let uniform (type t) (arr : t array) : t prb =
  match Array.length arr with
  | 0 -> failwith "uniform: empty array"
  | len ->
      let mass = Reals.(div one (of_int len)) in
      let iter k = Array.iter k arr in
      Stats_intf.P { fn = Vec.Vec (iter, fun _elt -> mass) }
(* [eval_prb prb x] is the mass that [prb] assigns to [x], obtained by
   calling the evaluation function stored in [prb]. *)
let eval_prb (type t) (P { fn = Vec.Vec (_, mass_of) } : t prb) (x : t) : r =
  mass_of x
(* [eval_mes mes x] is the weight that [mes] assigns to [x], obtained by
   calling the evaluation function stored in [mes]. *)
let eval_mes (type t) (M { fn; total_mass = _ } : t mes) (x : t) : r =
  match fn with Vec.Vec (_, weight_of) -> weight_of x
(* [iter_prb prb f] calls [f x mass] for every point [x] enumerated by the
   iterator of [prb], where [mass] is the value of [prb] at [x]. *)
let iter_prb (type t) (P { fn = Vec.Vec (visit, mass_of) } : t prb) f =
  let on_point x = f x (mass_of x) in
  visit on_point
(* [iter_mes mes f] calls [f x weight] for every point [x] enumerated by
   the iterator of [mes], where [weight] is the value of [mes] at [x]. *)
let iter_mes (type t) (M { fn; total_mass = _ } : t mes) f =
  match fn with
  | Vec.Vec (visit, weight_of) -> visit (fun x -> f x (weight_of x))
(* [integrate mes f] is the sum over the points [x] enumerated by [mes] of
   the weight of [x] times [f x]. The recorded total mass is not used, so
   the result is not normalised. *)
let integrate (type t) (M { fn; total_mass = _ } : t mes) (f : t -> r) : r =
  let (Vec.Vec (visit, weight_of)) = fn in
  let running = ref Reals.zero in
  let add_term x =
    let fx = f x in
    let wx = weight_of x in
    running := Reals.(!running + (wx * fx))
  in
  visit add_term ;
  !running
(* [list_of_measure mes] lists the points of [mes] paired with their
   weights, in iterator order, tagged with the [`Measure] variant. *)
let list_of_measure (type t) (M { fn; total_mass = _ } : t mes) =
  let (Vec.Vec (visit, weight_of)) = fn in
  (* Points are pushed in front, hence the final reversal. *)
  let rev_points = ref [] in
  let push x = rev_points := (x, weight_of x) :: !rev_points in
  visit push ;
  `Measure (List.rev !rev_points)
(* [list_of_probability prb] lists the points of [prb] paired with their
   masses, in iterator order, tagged with the [`Probability] variant. *)
let list_of_probability (type t) (P { fn } : t prb) =
  let (Vec.Vec (visit, mass_of)) = fn in
  (* Points are pushed in front, hence the final reversal. *)
  let rev_points = ref [] in
  let push x = rev_points := (x, mass_of x) :: !rev_points in
  visit push ;
  `Probability (List.rev !rev_points)
(* [pp_fin_mes pp fmtr den] prints the points of the measure [den], in
   iterator order, inside a horizontal box. Each point is printed as a
   pair made of the element (printed with [pp]) and its weight (printed
   with [Reals.pp]). *)
let pp_fin_mes :
    type a.
    (Format.formatter -> a -> unit) ->
    Format.formatter ->
    (a, r) Stats_intf.fin_mes ->
    unit =
 fun pp fmtr den ->
  let (`Measure points) = list_of_measure den in
  let pp_point elt_fmt (elt, pr) =
    Format.fprintf elt_fmt "(%a, %a);@," pp elt Reals.pp pr
  in
  Format.fprintf fmtr "@[<h>%a@]" (Format.pp_print_list pp_point) points
(* [pp_fin_mes_by_measure pp fmtr den] prints the points of the measure
   [den] like [pp_fin_mes] does, after sorting them by increasing weight
   according to [Reals.compare]. *)
let pp_fin_mes_by_measure :
    type a.
    (Format.formatter -> a -> unit) ->
    Format.formatter ->
    (a, r) Stats_intf.fin_mes ->
    unit =
 fun pp fmtr den ->
  let (`Measure points) = list_of_measure den in
  let by_weight (_, r1) (_, r2) = Reals.compare r1 r2 in
  let sorted = List.sort by_weight points in
  let pp_point elt_fmt (elt, pr) =
    Format.fprintf elt_fmt "(%a, %a);@," pp elt Reals.pp pr
  in
  Format.fprintf fmtr "@[<h>%a@]" (Format.pp_print_list pp_point) sorted
(* [coin ~bias] is the probability on booleans giving mass [bias] to
   [true] and [one - bias] to [false]. The result goes through
   [probability], so the checks of that function apply as well.
   Raises [Invalid_argument] when [bias] is below [zero] or above [one]. *)
let coin ~bias : bool prb =
  let out_of_range = Reals.(bias < zero || bias > one) in
  if out_of_range then invalid_arg "coin: invalid bias" ;
  let compl = Reals.(one - bias) in
  (* [true] is enumerated before [false]. *)
  let outcomes k = List.iter k [ true; false ] in
  let mass_of heads = if heads then bias else compl in
  probability (Vec.Vec (outcomes, mass_of))
(* [bincoeff n k] is the product, for [i] ranging over [1 .. k], of
   [(n + 1 - i) / i], computed in [Reals]; for [k = 0] the result is
   [Reals.one].
   The counter starts at 1 and the loop stops only when it equals
   [k + 1], so [k] is expected to be non-negative. *)
let bincoeff n k =
  let n = Reals.of_int n in
  let stop = k + 1 in
  let coeff = ref Reals.one in
  let i = ref 1 in
  while not (Int.equal !i stop) do
    let fi = Reals.of_int !i in
    coeff := Reals.(!coeff * ((n + one - fi) / fi)) ;
    incr i
  done ;
  !coeff
(* [binomial coin n] is the distribution of the number of [true] outcomes
   among [n] independent draws of [coin].

   The support is [0 .. n], both bounds included: outcome [k] receives the
   weight [bincoeff n k * p^k * (1 - p)^(n - k)], where [p] is the mass of
   [true] under [coin]. The weights are normalised at the end, which
   absorbs rounding errors of the arithmetic of [Reals].

   Raises [Invalid_argument] when [n] is negative. *)
let binomial (coin : bool prb) n =
  if n < 0 then invalid_arg "binomial: negative number of trials" ;
  let p = eval_prb coin true in
  let not_p = eval_prb coin false in
  let elements =
    (* n + 1 outcomes: the case where all n draws are [true] is included. *)
    Array.init (n + 1) (fun k ->
        let n_minus_k = n - k in
        Reals.(k, bincoeff n k * npow p k * npow not_p n_minus_k))
  in
  let finfn = of_assoc (module Helpers.Int_table) elements in
  normalize @@ measure finfn
(* [mean_generic (module L) mes] is the sum, over the points [x]
   enumerated by [mes], of [L.smul w x] where [w] is the weight of [x].
   The sum starts from [L.zero] and is not divided by the total mass. *)
let mean_generic (type t)
    (module L : Basic_intf.Module with type t = t and type R.t = Reals.t)
    (M { fn = Vec.Vec (visit, weight_of); total_mass = _ } : t mes) =
  let weighted_sum = ref L.zero in
  let add_point x =
    let so_far = !weighted_sum in
    weighted_sum := L.add (L.smul (weight_of x) x) so_far
  in
  visit add_point ;
  !weighted_sum
(* [mean dist] integrates the identity function against [dist]: the sum of
   each point times its weight, not divided by the total mass. *)
let mean (dist : r mes) : r =
  let identity (x : r) = x in
  integrate dist identity
(* [variance dist] is the sum, over the points [x] enumerated by [dist],
   of the squared distance between [x] and [mean dist], times the weight
   of [x]. The total mass is not used. *)
let variance (M { fn; total_mass = _ } as dist : r mes) =
  let (Vec.Vec (visit, m)) = fn in
  let mu = mean dist in
  let spread = ref Reals.zero in
  let accumulate x =
    let delta = Reals.(x - mu) in
    let delta_squared = Reals.(delta * delta) in
    spread := Reals.(!spread + (delta_squared * m x))
  in
  visit accumulate ;
  !spread
(* [quantile (module O) mes p] returns the smallest point, in the order of
   [O.compare], at which the cumulative weight of [mes] reaches [p] times
   the total mass. When rounding prevents the cumulative weight from
   reaching that threshold, the largest point is returned.
   Raises [Invalid_argument] when [p] is outside of [zero .. one] or when
   the total mass is zero. *)
let quantile (type elt) (module O : Basic_intf.Ordered with type t = elt)
    (M { fn; total_mass } : elt mes) (p : r) =
  if Reals.(p < zero) || Reals.(p > one) then
    invalid_arg "quantile (invalid p)" ;
  if Reals.(total_mass = Reals.zero) then invalid_arg "quantile (zero mass)" ;
  let threshold = Reals.(p * total_mass) in
  let (Vec.Vec (iter, m)) = fn in
  (* Collect the weighted points, then sort them by element. *)
  let elts = ref [] in
  iter (fun x -> elts := (x, m x) :: !elts) ;
  let arr = Array.of_list !elts in
  Array.sort (fun (x, _) (y, _) -> O.compare x y) arr ;
  let count = Array.length arr in
  (* Walk the sorted points, carrying the cumulative weight along. *)
  let rec scan i cumulated =
    if i >= count then fst arr.(count - 1)
    else
      let (x, q) = arr.(i) in
      let cumulated = Reals.(cumulated + q) in
      if Reals.(cumulated >= threshold) then x else scan (i + 1) cumulated
  in
  scan 0 Reals.zero
(* [fold_union (module H) f m1 m2 acc] folds [f] over the union of the
   points enumerated by [m1] and [m2].

   For each point [x], [f x r1 r2 acc] is called exactly once, where [r1]
   is the weight of [x] under [m1] and [r2] its weight under [m2]. A point
   enumerated by only one of the two measures gets [Reals.zero] for the
   other weight. The argument order is the same in all three cases, so [f]
   does not need to be symmetric in [r1] and [r2].

   Points of [m2] are visited first, in iterator order. The points that
   only belong to [m1] follow, in the order of [H.fold]. *)
let fold_union (type x) (module H : Hashtbl.S with type key = x) f
    (m1 : x mes) (m2 : x mes) acc =
  let (M { fn = Vec.Vec (iter1, m1); total_mass = _ }) = m1 in
  let (M { fn = Vec.Vec (iter2, m2); total_mass = _ }) = m2 in
  let table = H.create 127 in
  (* [replace] keeps a single binding per key even if [iter1] repeats it. *)
  iter1 (fun x -> H.replace table x (m1 x)) ;
  let acc = ref acc in
  iter2 (fun y ->
      let r2 = m2 y in
      match H.find_opt table y with
      | None -> acc := f y Reals.zero r2 !acc
      | Some r1 ->
          (* Shared point: drop it so that the final fold skips it. *)
          H.remove table y ;
          acc := f y r1 r2 !acc) ;
  (* Whatever remains in the table belongs to [m1] only. *)
  H.fold (fun x r1 acc -> f x r1 Reals.zero acc) table !acc
end
[@@inline]
(* Finite measures with floating-point weights, plus a few divergences and
   distances between them. *)
module Float = struct
  include Make (Basic_impl.Reals.Float)

  module Dist = struct
    (* [kl h m1 m2] is the Kullback-Leibler divergence of [m1] from [m2]:
       the sum over the union of their points of [r1 * log (r1 / r2)],
       where [r1] and [r2] are the weights under [m1] and [m2].
       A point with zero weight under [m1] contributes nothing, following
       the usual convention; without this guard the term would evaluate to
       [nan]. A point with positive [r1] and zero [r2] contributes
       [infinity]. *)
    let kl h m1 m2 =
      fold_union
        h
        (fun _ r1 r2 acc ->
          if Stdlib.Float.equal r1 0.0 then acc
          else acc +. (r1 *. log (r1 /. r2)))
        m1
        m2
        0.0

    (* [lp h ~p m1 m2] is the [p]-th root of the sum, over the union of the
       points of [m1] and [m2], of the absolute difference of their
       weights raised to the power [p].
       Raises [Invalid_argument] when [p] is below 1. *)
    let lp h ~p m1 m2 =
      if p <. 1. then invalid_arg "lp: p < 1" ;
      let res =
        fold_union
          h
          (fun _ r1 r2 acc -> acc +. (abs_float (r1 -. r2) ** p))
          m1
          m2
          0.0
      in
      res ** (1. /. p)

    (* [maxf x y] returns [y] when [x] is strictly below [y], else [x]. *)
    let maxf x y = if x <. y then y else x

    (* [linf h m1 m2] is the largest absolute difference between the
       weights of [m1] and [m2] over the union of their points; 0 when
       both enumerate no point. *)
    let linf h m1 m2 =
      fold_union
        h
        (fun _ r1 r2 acc -> maxf acc (abs_float (r1 -. r2)))
        m1
        m2
        0.0
  end
end
(* Finite measures with rational weights, plus a distance between them. *)
module Rational = struct
  include Make (Basic_impl.Reals.Rational)

  module Dist = struct
    (* [linf h m1 m2] is the largest absolute difference between the
       weights of [m1] and [m2] over the union of their points; [Q.zero]
       when both enumerate no point. *)
    let linf h m1 m2 =
      let widest_gap _ r1 r2 acc =
        let gap = Q.abs (Q.sub r1 r2) in
        Q.max acc gap
      in
      fold_union h widest_gap m1 m2 Q.zero
  end
end