package prbnmcn-dagger-test

  1. Overview
  2. Docs

Source file resampling_test.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
(** Test resampling (used in SMC) *)

open Dagger

(** {2 Helpers} *)

(** Fixed seed words for the test RNG. Using a literal seed makes the
    sampled values depend only on this array, not on the environment. *)
let initial_state =
  [| 0x1337533D;
     71287309;
     666932349;
     719132214;
     461480042;
     387006837;
     443018964;
     450865457;
     901711679;
     833353016;
     397060904;
     811875353
  |]

(* Shared RNG state; it is the default [?state] of the resampling helpers
   further down. The seed is copied so [initial_state] itself is never handed
   to [RNG.make] (NOTE(review): presumably because [RNG.make] may retain or
   mutate its argument — confirm). *)
let state = RNG.make (Array.copy initial_state)

(** [pp_arr pp fmtr arr] prints the elements of [arr] with [pp], separated by
    a single comma and no spaces. An empty array prints nothing. *)
let pp_arr pp fmtr arr =
  Array.iteri
    (fun i elt ->
      (* Separator goes before every element except the first. *)
      if i > 0 then Format.pp_print_string fmtr "," ;
      pp fmtr elt)
    arr

(** Rationals from [Basic_structures], extended with a printer that renders
    each value as its decimal float approximation. *)
module Rat = struct
  include Basic_structures.Basic_impl.Reals.Rational

  let pp fmtr q =
    let approx = Q.to_float q in
    Format.fprintf fmtr "%f" approx
end

(** Uniform sampler over the rationals. [uniform bound] converts [bound] to a
    float once and returns a function that draws [RNG.float rng bound_f] and
    converts the sample back to an exact rational. It is the source of
    randomness passed to [Resampling.Make_predefined] when building [R]. *)
let uniform bound =
  let bound_f = Q.to_float bound in
  fun rng ->
    let sample = RNG.float rng bound_f in
    Q.of_float sample

(** Finite measures with rational weights, together with a QCheck generator.
    A measure is split into two mutable particle lists, [active] and
    [suspended]; the generator fills [suspended] and leaves [active] empty. *)
module Dist = struct
  type 'a t =
    { mutable active : ('a * Q.t) list; mutable suspended : ('a * Q.t) list }

  (* Prints [active] then [suspended], each as comma-separated
     [(element,weight)] pairs with weights shown as floats. *)
  let pp pp_elt fmtr { active; suspended } =
    let pp_pair fmtr (elt, weight) =
      Format.fprintf fmtr "(%a,%a)" pp_elt elt Rat.pp weight
    in
    let pp_mes fmtr pairs =
      Format.pp_print_list
        ~pp_sep:(fun fmtr () -> Format.pp_print_string fmtr ",")
        pp_pair
        fmtr
        pairs
    in
    Format.fprintf fmtr "active=%a, suspended=%a" pp_mes active pp_mes suspended

  (* [simpl] sorts the measure by support point and merges entries whose
     points are equal, summing their weights. Ordering and equality are the
     polymorphic [Stdlib.compare] / [Stdlib.( = )], which is not pretty. *)
  let simpl : ('a * Q.t) list -> ('a * Q.t) list =
   fun l ->
    let sorted = List.sort (fun (x, _) (y, _) -> Stdlib.compare x y) l in
    let rec merge = function
      | (x1, w1) :: (x2, w2) :: rest when Stdlib.( = ) x1 x2 ->
          merge ((x1, Q.add w1 w2) :: rest)
      | hd :: rest -> hd :: merge rest
      | [] -> []
    in
    merge sorted

  (* Generates between 2 and [small_nat + 2] points, each paired with a
     strictly positive weight; weights are divided by their sum so that the
     resulting measure has total mass one. *)
  let gen_list : 'a QCheck.Gen.t -> ('a * Q.t) list QCheck.Gen.t =
   fun gen ->
    let open QCheck.Gen in
    let at_least_two = map (fun n -> n + 2) small_nat in
    let of_length n g = list_size (return n) g in
    at_least_two >>= fun length ->
    of_length length gen >>= fun points ->
    of_length length at_least_two >>= fun raw_weights ->
    let weights = List.map Q.of_int raw_weights in
    let total = List.fold_left Q.add Q.zero weights in
    let normalized = List.map (fun w -> Q.div w total) weights in
    return (List.combine points normalized)

  let gen : 'a QCheck.Gen.t -> 'a t QCheck.Gen.t =
   fun gen ->
    QCheck.Gen.map (fun suspended -> { active = []; suspended }) (gen_list gen)

  (* Swaps the two particle lists. *)
  let flip d = { active = d.suspended; suspended = d.active }

  (* Fresh record sharing both lists with [d]. *)
  let copy d = { active = d.active; suspended = d.suspended }
end

(* ------------------------------------------------------------------------- *)
(* Property: the weights produced by [Dist.gen] always sum to exactly one. *)
let is_generator_normalized =
  let pp_weights =
    Format.pp_print_list
      ~pp_sep:(fun fmtr () -> Format.fprintf fmtr ",")
      (fun fmtr q -> Format.fprintf fmtr "%f" (Q.to_float q))
  in
  let weights_of_measure d = List.map snd d.Dist.suspended in
  let arbitrary =
    QCheck.make
      (QCheck.Gen.map weights_of_measure (Dist.gen QCheck.Gen.small_nat))
  in
  let property proba =
    let total = List.fold_left Q.add Q.zero proba in
    if Q.(total <> one) then
      QCheck.Test.fail_reportf
        "expected normalized measure, got %f (measure = %a)"
        (Q.to_float total)
        pp_weights
        proba ;
    true
  in
  QCheck.Test.make ~count:1000 ~name:"is_generator_normalized" arbitrary property

(* ---------------------------------------------------------------- *)

(** Resampling primitives instantiated over the field of arbitrary-precision
    rationals, drawing randomness through [uniform]. *)
module R = Resampling.Make_predefined (Rat) (struct
  let uniform = uniform
end)

(** [environment pop] exposes [pop] as a population for the [Resampling]
    functions. Particles and their scores are read from [pop.suspended], and
    every particle given to [append] is pushed, with its new score, onto
    [pop.active]. [get_output] always returns [None] and [ess] always returns
    zero. *)
let environment (type a) (pop : a Dist.t) : (a, Q.t) Resampling.particles =
  let module Env = struct
    type p = a * Q.t

    type o = a

    type r = Q.t

    let get_output _ = None

    let get_score (_, score) = score

    let fold f init =
      List.fold_left
        (fun acc p -> f acc p (get_score p))
        init
        pop.Dist.suspended

    let iter f = fold (fun () p score -> f p score) ()

    let append ((elt, _) : p) score =
      pop.Dist.active <- (elt, score) :: pop.Dist.active

    let total () = fold (fun acc _ score -> Q.add acc score) Q.zero

    let size () = List.length pop.Dist.suspended

    let ess () = Q.zero
  end in
  (module Env)

(* NOTE(review): not raised anywhere in the visible code of this file;
   possibly unused — confirm before removing. *)
exception Invalid_population

(** [list_empty l] is [true] exactly when [l] has no elements. *)
let list_empty l = match l with [] -> true | _ :: _ -> false

(** [resample mu resampling rng_state] runs [resampling] on the population
    view of [mu] built by [environment], using [rng_state] as randomness. *)
let resample mu
    (resampling : ('a, Q.t) Resampling.particles -> unit -> RNG.t -> unit)
    rng_state =
  let particles = environment mu in
  resampling particles () rng_state

(** Sum of the weights in [mu.suspended], as computed by [environment]. *)
let total_mass (type a) (mu : a Dist.t) =
  let module E = (val environment mu) in
  E.total ()

(** Number of particles in [mu.suspended], as computed by [environment]. *)
let cardinal (type a) (mu : a Dist.t) =
  let module E = (val environment mu) in
  E.size ()

(* ------------------------------------------------------------------------- *)
(* Test stratified resampling. *)

(** [isum a] is the sum of the integers in [a] ([0] for an empty array). *)
let isum a =
  let acc = ref 0 in
  Array.iter (fun x -> acc := !acc + x) a ;
  !acc

(** Applies [R.resampling_generic_iterative f] to [mu], drawing from [state]
    (by default the shared test RNG state). *)
let iterative_resampling_generic ?(state = state) mu f =
  let run particles () rng = R.resampling_generic_iterative f particles rng in
  resample mu run state

(* Iterative stratified resampling of [mu]. The target size equals the
   current population size, and the ESS threshold is [Q.one]. *)
let iterative_stratified_resampling ?(state = state) mu =
  let target_size = List.length mu.Dist.suspended in
  let strategy = R.stratified_resampling ~ess_threshold:Q.one ~target_size in
  resample mu strategy state

(** [iter ?pp ?msg f mu0] checks an in-place resampling procedure [f].

    [f] is applied to a copy of [mu0] whose [active] list starts empty. The
    particles [f] appends there form the resampled population. The check
    fails, via [QCheck.Test.fail_reportf] with [msg] as a prefix, when that
    population
    - does not have the same total mass as [mu0.suspended],
    - does not have the same number of particles, or
    - contains an element whose merged weight in [mu0.suspended] is zero.

    [mu0]'s fields are not modified. [pp] prints elements. The population is
    printed on stdout before and after resampling. *)
let iter ?(pp = fun _ _ -> ()) ?(msg = "") f mu0 =
  let total = total_mass mu0 in
  let card = cardinal mu0 in
  (* Elements whose weights, once duplicates are merged, sum to zero must
     never be selected by resampling. *)
  let zero_mass_elements =
    List.filter_map
      (fun (x, q) -> if Q.equal q Q.zero then Some x else None)
      (Dist.simpl mu0.Dist.suspended)
  in
  let mu1 = Dist.copy mu0 in
  (* Start from an empty output list so that only freshly appended particles
     are counted, even if [mu0.active] was not empty. *)
  mu1.Dist.active <- [] ;
  Format.printf "before %a@." (Dist.pp pp) mu1 ;
  f mu1 ;
  Format.printf "after %a@." (Dist.pp pp) mu1 ;
  (* After [flip], the resampled particles sit in [suspended], which is the
     list [total_mass] and [cardinal] inspect. *)
  let mu1 = Dist.flip mu1 in
  let total' = total_mass mu1 in
  let card' = cardinal mu1 in
  if not (Q.equal total total') then
    QCheck.Test.fail_reportf
      "%s total mass not preserved (%f vs %f)"
      msg
      (Q.to_float total)
      (Q.to_float total')
  else if card <> card' then
    QCheck.Test.fail_reportf
      "%s cardinality not preserved (%a vs %a)"
      msg
      (Dist.pp pp)
      mu0
      (Dist.pp pp)
      mu1
  else
    let has_zero_elements =
      List.exists
        (fun (x, _) -> List.mem x zero_mass_elements)
        mu1.Dist.suspended
    in
    if has_zero_elements then
      (* Two [%a] directives: both measures are actually printed. *)
      QCheck.Test.fail_reportf
        "%s has zero mass elements (%a vs %a)"
        msg
        (Dist.pp pp)
        mu0
        (Dist.pp pp)
        mu1

(* Deterministic resampling (the per-index draw [f] ignores the RNG) of a
   measure whose only positive-weight points are 5 and 6. The resampled
   support must be non-empty and contained in {5, 6}. *)
let test_stratified_on_handcrafted =
  let rand = Q.of_float 0.12 in
  let f i _rng_state = Q.add rand (Q.div (Q.of_int i) (Q.of_int 7)) in
  let resampling mu = iterative_resampling_generic mu f in
  QCheck.Test.make
    ~name:"test_stratified_on_handcrafted"
    ~count:1
    (QCheck.make
       QCheck.Gen.(
         return
           [ (0, Q.zero);
             (1, Q.zero);
             (2, Q.zero);
             (3, Q.zero);
             (4, Q.zero);
             (5, Q.(1 // 3));
             (6, Q.(2 // 3)) ]))
    (fun proba ->
      let measure = { Dist.suspended = proba; active = [] } in
      iter
        ~pp:Format.pp_print_int
        ~msg:"test_stratified_on_handcrafted"
        resampling
        measure ;
      (* [iter] resamples a copy, so [measure.active] is still empty here.
         Resample a fresh copy to obtain a population whose support can be
         inspected; otherwise the check below would pass vacuously. *)
      let resampled = Dist.copy measure in
      resampling resampled ;
      let support = List.map fst resampled.Dist.active in
      let correct =
        support <> [] && List.for_all (fun i -> List.mem i [5; 6]) support
      in
      if not correct then
        QCheck.Test.fail_reportf
          "invalid support (measure = %a)"
          (Dist.pp Format.pp_print_int)
          resampled ;
      true)

(* Property: on random normalized measures with at least two points, iterative
   stratified resampling passes all the invariants checked by [iter]. *)
let test_iterative_stratified =
  let name = "test_iterative_stratified" in
  let arbitrary = QCheck.make (Dist.gen QCheck.Gen.small_nat) in
  let property measure =
    assert (List.length measure.Dist.suspended >= 2) ;
    iter
      ~pp:Format.pp_print_int
      ~msg:name
      iterative_stratified_resampling
      measure ;
    true
  in
  QCheck.Test.make ~count:1000 ~name arbitrary property

(* ------------------------------------------------------------------------- *)

(** The QCheck tests defined in this module. *)
let tests =
  let generator_tests = [ is_generator_normalized ] in
  let resampling_tests =
    [ test_stratified_on_handcrafted; test_iterative_stratified ]
  in
  generator_tests @ resampling_tests
OCaml

Innovation. Community. Security.