package prbnmcn-stats

  1. Overview
  2. Docs

Source file fin_dist.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
module Make (Reals : Basic_intf.Reals) = struct
  (** Scalars used for weights throughout this functor; an alias for the
      carrier type of the [Reals] functor argument. *)
  type reals = Reals.t

  (** A function with finite support, packaged together with the free-module
      implementation used to store its weights. *)
  module type Fin_fun = sig
    (** Type of the points carrying a weight. *)
    type t

    (** Free module over [Reals.t] whose basis elements are values of type
        [t]. *)
    module V : Basic_intf.Free_module with type R.t = Reals.t and type basis = t

    (** The weights, stored as a single vector of [V]. *)
    val weightmap : V.t
  end

  (** A kernel: a function mapping each input of type [t] to a finitely
      supported weight vector over outputs of type [u]. *)
  module type Fin_kernel = sig
    (** Type of the inputs of the kernel. *)
    type t

    (** Type of the points of the output vectors. *)
    type u

    (** Free module over [Reals.t] whose basis elements are values of type
        [u]. *)
    module V : Basic_intf.Free_module with type R.t = Reals.t and type basis = u

    (** [kernel x] is the weight vector associated to the input [x]. *)
    val kernel : t -> V.t
  end

  (** First-class module packaging a finitely supported function over
      points of type ['a]. *)
  type 'a fin_fun = (module Fin_fun with type t = 'a)

  (** Finitely supported density. This is the same type as ['a fin_fun]. *)
  type 'a fin_den = 'a fin_fun

  (* A finitely supported probability is normalized. Note that this is only a
     convention: the type is the same alias as [fin_den] and does not enforce
     normalization by itself. *)
  type 'a fin_prb = 'a fin_fun

  (** First-class module packaging a kernel from inputs of type ['a] to
      weight vectors over points of type ['b]. *)
  type ('a, 'b) kernel = (module Fin_kernel with type t = 'a and type u = 'b)

  (** [sample_prb prb rng_state] draws one point from [prb].

      A threshold is obtained from [Reals.lebesgue rng_state] (presumably a
      uniform draw on the unit interval -- TODO confirm against [Reals]). The
      weights are then accumulated in the iteration order of [V.fold], and the
      first point whose cumulative weight reaches the threshold is returned.

      @raise Assert_failure if the fold finishes without any cumulative weight
      reaching the threshold (e.g. empty support). *)
  let sample_prb : type a. a fin_prb -> Random.State.t -> a =
    fun (type t) (module P : Fin_fun with type t = t) rng_state ->
     (* A local exception is used to leave the fold as soon as a point is
        selected. *)
     let exception Found of t in
     let threshold = Reals.lebesgue rng_state in
     let step point weight running_sum =
       let running_sum = Reals.add running_sum weight in
       if Reals.(threshold <= running_sum) then raise (Found point)
       else running_sum
     in
     match P.V.fold step P.weightmap Reals.zero with
     | (_ : Reals.t) -> assert false
     | exception Found point -> point

  (** [density (module V) elements] packs the weighted points [elements] into
      a finitely supported density whose weight vector is
      [V.of_list elements]. No normalization or validation is performed
      here. *)
  let density (type t)
      (module V : Basic_intf.Free_module
        with type R.t = Reals.t
         and type basis = t) (elements : (t * Reals.t) list) : t fin_den =
    let module Packed = struct
      type nonrec t = t

      module V = V

      let weightmap = V.of_list elements
    end in
    (module Packed)

  (** [probability (module V) elements] builds a finitely supported
      probability from weighted points.

      The weights are summed with [Reals.add], starting from [Reals.zero], in
      list order, and the total is compared to [Reals.one] with
      [Reals.compare]; the comparison is exact.

      @raise Invalid_argument if the total weight is not equal to
      [Reals.one]. *)
  let probability (type t)
      (module V : Basic_intf.Free_module
        with type R.t = Reals.t
         and type basis = t) (elements : (t * Reals.t) list) : t fin_prb =
    let total_weight =
      List.fold_left
        (fun sum (_point, weight) -> Reals.add sum weight)
        Reals.zero
        elements
    in
    if Reals.compare total_weight Reals.one = 0 then density (module V) elements
    else invalid_arg "Stats.probability: weights do not sum up to 1"

  (** [total_mass den] is the sum of all the weights stored in [den]. *)
  let total_mass (type t) ((module D) : t fin_den) : Reals.t =
    let accumulate _point weight sum = Reals.add weight sum in
    D.V.fold accumulate D.weightmap Reals.zero

  (** [normalize den] rescales every weight of [den] by the inverse of
      [total_mass den].

      NOTE(review): no check is made that the total mass is non-zero; the
      outcome of the division then depends on [Reals.div] -- confirm. *)
  let normalize (type t) ((module D) : t fin_den) : t fin_prb =
    let inverse_mass = Reals.div Reals.one (total_mass (module D)) in
    let module Normalized = struct
      type t = D.t

      module V = D.V

      let weightmap = V.smul inverse_mass D.weightmap
    end in
    (module Normalized)

  (** [fin_prb_of_empirical (module V) samples] builds the empirical
      distribution of [samples]: every array cell contributes a weight of
      [Reals.one] to its point (summed with [V.add]), and the result is then
      passed through [normalize].

      NOTE(review): an empty array yields a zero total mass before
      normalization -- confirm how [Reals.div] behaves in that case. *)
  let fin_prb_of_empirical (type t)
      (module V : Basic_intf.Free_module
        with type R.t = Reals.t
         and type basis = t) (p : t array) : t fin_den =
    let count_one counts sample =
      V.add counts (V.of_list [(sample, Reals.one)])
    in
    let module Counts = struct
      type nonrec t = t

      module V = V

      let weightmap = Array.fold_left count_one V.zero p
    end in
    normalize ((module Counts) : t fin_den)

  (** [uniform (module V) arr] gives each cell of [arr] the weight
      [1 / Array.length arr]; contributions are combined with [V.add].

      @raise Failure if [arr] is empty. *)
  let uniform (type t)
      (module V : Basic_intf.Free_module
        with type R.t = Reals.t
         and type basis = t) (arr : t array) : t fin_prb =
    match Array.length arr with
    | 0 -> failwith "uniform: empty array"
    | len ->
        let cell_weight = Reals.(div one (of_int len)) in
        let add_cell weights point =
          V.add (V.of_list [(point, cell_weight)]) weights
        in
        let module Uniform = struct
          type nonrec t = t

          module V = V

          let weightmap = Array.fold_left add_cell V.zero arr
        end in
        (module Uniform)

  (** [eval_prb prb x] is the weight that [prb] assigns to the point [x], as
      returned by [V.eval] on the weight vector. *)
  let eval_prb (type t) ((module P) : t fin_prb) (x : t) : Reals.t =
    let weights = P.weightmap in
    P.V.eval weights x

  (** [integrate prb f] is the weighted sum of [f] over the support of [prb]:
      the sum of [w * f x] for every point [x] with weight [w]. *)
  let integrate (type t) ((module P) : t fin_prb) (f : t -> Reals.t) : Reals.t =
    let accumulate point weight total =
      Reals.(total + (weight * f point))
    in
    P.V.fold accumulate P.weightmap Reals.zero

  (** [kernel ?h (module V) f] packs the function [f], which returns weighted
      points as a list, into a kernel whose outputs are vectors of [V].

      When [h] is supplied, results are memoised in a weak hash set
      ([Weak.Make]) keyed with [H.hash] and [H.equal]; since the set is weak,
      cached entries may be reclaimed by the garbage collector and recomputed
      later. Without [h], [f] is called on every evaluation. *)
  let kernel (type a b) ?(h : (module Basic_intf.Std with type t = a) option)
      (module V : Basic_intf.Free_module
        with type R.t = Reals.t
         and type basis = b) (kernel : a -> (b * Reals.t) list) : (a, b) kernel
      =
    (* Uncached evaluation, shared by both modes. *)
    let compute x = V.of_list (kernel x) in
    let memoized =
      match h with
      | None -> compute
      | Some (module H) ->
          let module Entry = struct
            type t = { key : H.t; data : V.t }

            (* Only the key takes part in equality and hashing. *)
            let equal left right = H.equal left.key right.key

            let hash entry = H.hash entry.key
          end in
          let module Cache = Weak.Make (Entry) in
          let cache = Cache.create 11 in
          fun x ->
            (* The [data] field of the probe is a placeholder: lookups only
               inspect [key]. *)
            let probe = { Entry.key = x; data = V.zero } in
            (match Cache.find_opt cache probe with
            | Some hit -> hit.Entry.data
            | None ->
                let data = compute x in
                Cache.add cache { Entry.key = x; data } ;
                data)
    in
    let module Packed = struct
      type t = a

      type u = b

      module V = V

      let kernel = memoized
    end in
    (module Packed)

  (** [compose ?h k1 k2] is the kernel obtained by chaining [k1] then [k2]:
      for an input [x], every intermediate point [b] with weight [pb] in
      [k1 x] contributes [pb] times the vector [k2 b] to the result.

      When [h] is supplied, results are memoised per input in a weak hash set
      ([Weak.Make]) keyed with [H.hash] and [H.equal]; cached entries may be
      reclaimed by the garbage collector. *)
  let compose :
      type a b c.
      ?h:(module Basic_intf.Std with type t = a) ->
      (a, b) kernel ->
      (b, c) kernel ->
      (a, c) kernel =
   fun ?h (module K1) (module K2) ->
    (* Uncached composition, shared by both modes. *)
    let chain x =
      let step mid weight acc =
        K2.V.add (K2.V.smul weight (K2.kernel mid)) acc
      in
      K1.V.fold step (K1.kernel x) K2.V.zero
    in
    let composed =
      match h with
      | None -> chain
      | Some (module H) ->
          let module Entry = struct
            type t = { key : H.t; data : K2.V.t }

            (* Only the key takes part in equality and hashing. *)
            let equal left right = H.equal left.key right.key

            let hash entry = H.hash entry.key
          end in
          let module Cache = Weak.Make (Entry) in
          let cache = Cache.create 11 in
          fun x ->
            (* The [data] field of the probe is a placeholder: lookups only
               inspect [key]. *)
            let probe = { Entry.key = x; data = K2.V.zero } in
            (match Cache.find_opt cache probe with
            | Some hit -> hit.Entry.data
            | None ->
                let data = chain x in
                Cache.add cache { Entry.key = x; data } ;
                data)
    in
    let module Composed = struct
      type t = K1.t

      type u = K2.u

      module V = K2.V

      let kernel = composed
    end in
    (module Composed)

  (** [constant_kernel prb] is the kernel that ignores its input and always
      returns the weight vector of [prb]. *)
  let constant_kernel : type a b. b fin_prb -> (a, b) kernel =
   fun (module Prb) ->
    let module Constant = struct
      type t = a

      type u = Prb.t

      module V = Prb.V

      let kernel (_ : t) = Prb.weightmap
    end in
    (module Constant)

  (** [eval_kernel x k] evaluates the kernel [k] at [x] and returns the
      resulting weighted points as an association list. The list is built by
      prepending during [V.fold], so its order is the reverse of the fold
      order. *)
  let eval_kernel : type a b. a -> (a, b) kernel -> (b * Reals.t) list =
   fun x (module K) ->
    let prepend point weight pairs = (point, weight) :: pairs in
    K.V.fold prepend (K.kernel x) []

  (** [raw_data_density den] exposes the weighted points of [den] as a list
      tagged with [`Density]. The list is built by prepending during [V.fold],
      so its order is the reverse of the fold order. *)
  let raw_data_density (type t) ((module D) : t fin_den) =
    let prepend point weight pairs = (point, weight) :: pairs in
    `Density (D.V.fold prepend D.weightmap [])

  (** [raw_data_probability prb] exposes the weighted points of [prb] as a
      list tagged with [`Probability]. The list is built by prepending during
      [V.fold], so its order is the reverse of the fold order. *)
  let raw_data_probability (type t) ((module D) : t fin_prb) =
    let prepend point weight pairs = (point, weight) :: pairs in
    `Probability (D.V.fold prepend D.weightmap [])

  (** [pp_fin_fun pp_point fmtr den] prints the weighted points of [den]
      inside a horizontal box. Each binding is printed as [(point, weight);]
      followed by a cut hint, using [pp_point] for the point and [Reals.pp]
      for the weight. *)
  let pp_fin_fun :
      (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a fin_den -> unit
      =
   fun kf f den ->
    let pp_binding fmtr (elt, pr) =
      Format.fprintf fmtr "(%a, %a);@," kf elt Reals.pp pr
    in
    match raw_data_density den with
    | `Density bindings ->
        Format.fprintf f "@[<h>%a@]" (Format.pp_print_list pp_binding) bindings

  (** [pushforward ~prior ~likelihood] is the image of [prior] through the
      kernel [likelihood]: every point [x] of [prior] with weight [px]
      contributes [px] times the vector [likelihood x] to the result. *)
  let pushforward (type t u) ~(prior : t fin_fun) ~(likelihood : (t, u) kernel)
      : u fin_prb =
    let (module Prior) = prior in
    let (module Likelihood) = likelihood in
    let push x px acc =
      Likelihood.V.add (Likelihood.V.smul px (Likelihood.kernel x)) acc
    in
    let module Image = struct
      type t = u

      module V = Likelihood.V

      let weightmap = Prior.V.fold push Prior.weightmap Likelihood.V.zero
    end in
    (module Image)

  (** [inverse ?h prior likelihood] returns the pushforward of [prior] through
      [likelihood], together with the reversed kernel: for an observation [y],
      each point [x] of [prior] receives the weight
      [prior x * likelihood x y / pushforward y].

      When [h] is supplied, the reversed kernel is memoised per observation in
      a weak hash set ([Weak.Make]) keyed with [H.hash] and [H.equal]; cached
      entries may be reclaimed by the garbage collector.

      NOTE(review): no check is made that the pushforward weight of [y] is
      non-zero before dividing -- the outcome depends on [Reals]; confirm. *)
  let inverse (type t u) ?(h : (module Basic_intf.Std with type t = u) option)
      (prior : t fin_prb) (likelihood : (t, u) kernel) :
      u fin_prb * (u, t) kernel =
    let (module Prior) = prior in
    let (module Likelihood) = likelihood in
    let (module Pushforward) = pushforward ~prior ~likelihood in
    (* Uncached reversed kernel. *)
    let posterior (y : u) =
      let nu_y = Pushforward.V.eval Pushforward.weightmap y in
      let add_point x mu_x acc =
        let f_x_y = Likelihood.V.eval (Likelihood.kernel x) y in
        let prob = Reals.(mul mu_x f_x_y / nu_y) in
        Prior.V.add acc (Prior.V.of_list [(x, prob)])
      in
      Prior.V.fold add_point Prior.weightmap Prior.V.zero
    in
    let memoized =
      match h with
      | None -> posterior
      | Some (module H) ->
          let module Entry = struct
            type t = { key : H.t; data : Prior.V.t }

            (* Only the key takes part in equality and hashing. *)
            let equal left right = H.equal left.key right.key

            let hash entry = H.hash entry.key
          end in
          let module Cache = Weak.Make (Entry) in
          let cache = Cache.create 11 in
          fun y ->
            (* The [data] field of the probe is a placeholder: lookups only
               inspect [key]. *)
            let probe = { Entry.key = y; data = Prior.V.zero } in
            (match Cache.find_opt cache probe with
            | Some hit -> hit.Entry.data
            | None ->
                let data = posterior y in
                Cache.add cache { Entry.key = y; data } ;
                data)
    in
    let module Kernel = struct
      type t = Likelihood.u

      type u = Likelihood.t

      module V = Prior.V

      let kernel = memoized
    end in
    ((module Pushforward), (module Kernel))

  (** Free module over [Reals] with basis [Std.Bool], backed by
      [Basic_impl.Bool_map]. Used by [coin]. *)
  module Bool_vec =
    Basic_impl.Free_module.Make (Std.Bool) (Reals) (Basic_impl.Bool_map)
  (** Free module over [Reals] with basis [Std.Int], backed by
      [Basic_impl.Int_map]. Used by [binomial]. *)
  module Int_vec =
    Basic_impl.Free_module.Make (Std.Int) (Reals) (Basic_impl.Int_map)

  (** [coin ~bias] is the distribution on booleans giving weight [bias] to
      [true] and [one - bias] to [false].

      @raise Failure if [bias] is below [Reals.zero] or above [Reals.one]. *)
  let coin ~bias : bool fin_prb =
    match Reals.(bias < zero || bias > one) with
    | true -> failwith "Stats.coin: invalid bias"
    | false ->
        let complement = Reals.(one - bias) in
        density (module Bool_vec) [(true, bias); (false, complement)]

  (** [bincoeff n k] is the binomial coefficient [n choose k], computed in
      [Reals] as the product of [(n + 1 - i) / i] for [i] ranging from [1] to
      [k]. In particular [bincoeff n 0] is [Reals.one].

      A negative [k] yields [Reals.zero], following the usual convention.
      Without this guard the loop would never reach its stopping index when
      [k] is negative. *)
  let bincoeff n k =
    if k < 0 then Reals.zero
    else
      let n = Reals.of_int n in
      let rec loop i acc =
        (* [i] only ever increases by one from [1], so [i > k] stops at the
           same point as testing [i = k + 1] and is safe for every [k]. *)
        if i > k then acc
        else
          let fi = Reals.of_int i in
          loop (i + 1) Reals.(acc * ((n + one - fi) / fi))
      in
      loop 1 Reals.one

  (** [binomial coin n] is the distribution of the number of [true] outcomes
      among [n] independent draws of [coin].

      Writing [p] for the weight of [true] and [q] for the weight of [false]
      in [coin], the outcome [k] receives the weight
      [bincoeff n k * p^k * q^(n - k)] for every [k] from [0] to [n]
      inclusive, so the support has [n + 1] points.

      @raise Invalid_argument if [n] is negative. *)
  let binomial (coin : bool fin_prb) n =
    if n < 0 then invalid_arg "Stats.binomial: negative number of trials" ;
    let p = eval_prb coin true in
    let not_p = eval_prb coin false in
    let elements =
      (* [n + 1] outcomes: the count of successes ranges over 0..n, both ends
         included. *)
      List.init (n + 1) (fun k ->
          let n_minus_k = n - k in
          Reals.(k, bincoeff n k * npow p k * npow not_p n_minus_k))
    in
    density (module Int_vec) elements

  (** [mean_generic (module L) dist] is the weighted sum of the points of
      [dist], computed in the module [L]: each point [x] with weight [w]
      contributes [L.smul w x], and contributions are summed with [L.add]
      starting from [L.zero]. *)
  let mean_generic (type elt)
      (module L : Basic_intf.Module with type t = elt and type R.t = Reals.t)
      ((module Dist) : elt fin_fun) =
    let add_weighted point weight total = L.add (L.smul weight point) total in
    Dist.V.fold add_weighted Dist.weightmap L.zero

  (** [mean dist] is the weighted sum of the points of a real-valued
      distribution, i.e. [integrate dist] applied to the identity function. *)
  let mean (dist : reals fin_fun) = integrate dist Fun.id

  (** [variance dist] is the weighted sum, over the points [x] of [dist], of
      the squared deviation [(x - m)^2] where [m] is [mean dist]. *)
  let variance ((module Dist) : reals fin_fun) =
    let m = mean (module Dist) in
    let squared_deviation x =
      let delta = Reals.(x - m) in
      Reals.(delta * delta)
    in
    let accumulate x w acc = Reals.(acc + (squared_deviation x * w)) in
    Dist.V.fold accumulate Dist.weightmap Reals.zero
end
OCaml

Innovation. Community. Security.