package prbnmcn-dagger-test

  1. Overview
  2. Docs

Source file poly.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
open Dagger.Smc_inference

(* Dense univariate polynomials with float coefficients.

   A polynomial is stored as an array of coefficients in increasing degree
   order: index [i] holds the coefficient of [x^i]. Trailing zero
   coefficients are never stripped. *)
module Poly = struct
  type t = float array

  (* The constant polynomial 0, stored with a single coefficient. *)
  let zero () = [| 0.0 |]

  (* [degree p] is the index of the last stored coefficient. Because trailing
     zeros are kept, this is an upper bound on the mathematical degree; it is
     [-1] for the empty array. *)
  let degree p = Array.length p - 1

  (* [get p i] is the coefficient of [x^i], reading 0 past the end of the
     array. Negative indices are not handled and raise [Invalid_argument]. *)
  let get p i = if i >= Array.length p then 0.0 else p.(i)

  (* Coefficient-wise sum; the result is as long as the longer operand. *)
  let add p1 p2 =
    let len = Int.max (Array.length p1) (Array.length p2) in
    Array.init len (fun i -> get p1 i +. get p2 i)

  (* [smul s a] multiplies every coefficient of [a] by the scalar [s]. *)
  let smul s a = Array.map (fun x -> x *. s) a

  (* [eval p x] evaluates [p] at [x] using Horner's scheme: one multiplication
     and one addition per coefficient instead of one [**] per term. The empty
     polynomial evaluates to 0. *)
  let eval p x = Array.fold_right (fun coeff acc -> coeff +. (acc *. x)) p 0.0

  (* [init deg f] builds an array of [deg + 1] elements, element [i] being
     [f i]. It is deliberately left polymorphic in the element type. *)
  let init deg f = Array.init (deg + 1) f

  (* [truncate deg p] resizes [p] to exactly [deg + 1] coefficients, dropping
     the higher ones or padding with zeros as needed. *)
  let truncate deg p = Array.init (deg + 1) (fun i -> get p i)

  (* Pretty-printer: non-zero terms in increasing degree order, separated by
     " + ". A polynomial without any non-zero coefficient is printed as
     "0.000000" rather than as the empty string. *)
  let pp fmtr (p : t) =
    let first = ref true in
    Array.iteri
      (fun i coeff ->
        if not (Float.equal coeff 0.0) then (
          (* No separator in front of the first printed term. *)
          let sep = if !first then "" else " + " in
          first := false ;
          match i with
          | 0 -> Format.fprintf fmtr "%s%f" sep coeff
          | 1 -> Format.fprintf fmtr "%s%f x" sep coeff
          | _ -> Format.fprintf fmtr "%s%f x^%d" sep coeff i))
      p ;
    (* Nothing was printed: make the zero polynomial visible. *)
    if !first then Format.fprintf fmtr "%f" 0.0
end

(* Type parameters passed to the [Make] functor application that builds
   [Smc]. *)
module Smc_types = struct
  (* Type of the values particles emit; [model] yields the current
     polynomial coefficients. *)
  type particle_output = Poly.t

  (* No resampling state is needed; presumably this is why [run_model]
     passes [()] to [Smc.run] -- confirm against the dagger API. *)
  type resampling_state = unit
end

(* SMC engine instantiated with [Smc_types]. [Make] is not defined in this
   file; presumably it comes from the opened [Dagger.Smc_inference] --
   confirm. The trailing [()] is part of the functor application. *)
module Smc = Make (Smc_types) ()
(* Short alias for the samplers used by [mutate]. *)
module Dists = Stats.Gen

(* One step of a random walk on polynomials.

   The new degree is drawn uniformly among the current degree and its two
   neighbours (the lower neighbour is clamped at 0). Independent standard
   gaussian noise is then added to every coefficient and the result is
   resized to the new degree. *)
let mutate (p : Poly.t) =
  let open Smc in
  let open Infix in
  let current_degree = Poly.degree p in
  let candidate_degrees =
    [| Int.max 0 (current_degree - 1); current_degree; current_degree + 1 |]
  in
  let* degree = sample (Dists.uniform candidate_degrees) in
  (* One noise sampler per coefficient of the new polynomial. *)
  let coefficient_noise _ = sample (Dists.gaussian ~mean:0.0 ~std:1.0) in
  let* noise = map_array (Poly.init degree coefficient_noise) Fun.id in
  return (Poly.truncate degree (Poly.add noise p))

(* Probabilistic model fitting a polynomial to a list of [(x, y)]
   observations.

   Observations are consumed one at a time. At each step the previous
   polynomial is mutated, the candidate is scored against every observation
   consumed so far (including the new one), a penalty on its degree is
   applied, and the candidate is yielded. *)
let model observations =
  let open Smc in
  let open Infix in
  (* Gaussian log-score of one observation under polynomial [poly]. *)
  let score_point poly (x, y) =
    let estimate = Poly.eval poly x in
    log_score (Stats_dist.Pdfs.gaussian_ln ~mean:y ~std:1.0 estimate)
  in
  (* Penalize high-degree polynomials. *)
  let penalize_degree poly =
    log_score
      (Stats_dist.Pdfs.exponential_ln ~rate:0.5 (float (Poly.degree poly)))
  in
  let rec step remaining seen previous =
    match remaining with
    | [] -> return ()
    | obs :: rest ->
        let* candidate = mutate previous in
        let seen = obs :: seen in
        let* () = List_ops.iter (score_point candidate) seen in
        let* () = penalize_degree candidate in
        let* () = yield candidate in
        step rest seen candidate
  in
  step observations [] (Poly.zero ())

(* Runs the SMC sampler on [model observations] with 10_000 particles and
   systematic resampling, and returns a memoized sequence with one polynomial
   per population that still has active particles: the weighted average of
   the active particles' coefficients, weights being normalised by the
   population's total mass. *)
let run_model observations rng_state =
  let weighted_mean pop =
    match Array.length pop.Smc.active with
    | 0 -> None
    | _ ->
        let inv_mass = 1. /. pop.total_mass in
        let accumulate acc (poly, weight) =
          Poly.add acc (Poly.smul (weight *. inv_mass) poly)
        in
        Some (Array.fold_left accumulate (Poly.zero ()) pop.active)
  in
  let populations =
    Smc.run
      (systematic_resampling ~ess_threshold:0.5)
      ()
      ~npart:10_000
      (model observations)
      rng_state
  in
  populations |> Seq.filter_map weighted_mean |> Seq.memoize

(* Coefficients of the polynomial used to generate the synthetic data, in
   increasing degree order: 3 + 25 x - 8 x^2 + 0.5 x^3. *)
let coeffs = [| 3.0; 25.0; -8.; 0.5 |]

(* Draws 150 synthetic observations [(x, y)] at x = 0.0, 0.1, ..., 14.9,
   where y is sampled from a gaussian centred on the reference polynomial
   [coeffs] evaluated at x, with standard deviation 10. *)
let noisy_observations rng_state =
  let observe index =
    let x = 0.1 *. float index in
    let y = Stats.Gen.gaussian ~mean:(Poly.eval coeffs x) ~std:10.0 rng_state in
    (x, y)
  in
  List.init 150 observe

(* Random state built from a fixed seed and shared by the whole file. *)
let rng_state = Random.State.make [| 149572; 3891981; 3847844 |]

(* Synthetic data set, drawn once when the module is initialised. Note that
   [rng_state] is passed to the gaussian sampler here, so presumably it has
   already been advanced by the time [run] uses it. *)
let observations = noisy_observations rng_state

(* Plots the observations [obs] (legend "obs") together with one line per
   point list in [coeffs], targeting the PNG file "smc_poly_obs.png".
   Does nothing unless [Helpers.produce_artifacts] is set. *)
let plot obs coeffs =
  let open Plot in
  if Helpers.produce_artifacts then (
    let observed = Line.line_2d ~legend:"obs" ~points:(Data.of_list obs) () in
    let fitted =
      List.map
        (fun points -> Line.line_2d ~points:(Data.of_list points) ())
        coeffs
    in
    let figure =
      plot2
        ~xaxis:"x"
        ~yaxis:"y"
        ~xrange:(Range.make ~min:0.0 ~max:15.0 ())
        ~yrange:(Range.make ~min:~-.70. ~max:70. ())
        (observed :: fitted)
    in
    let target = png ~pixel_size:(1280, 1024) ~png_file:"smc_poly_obs.png" () in
    run ~target exec figure)

(* Fits the model to [observations], plots every tenth posterior-mean
   polynomial against the observations, and prints every posterior-mean
   polynomial together with its step index.

   The model is fitted on the very same [observations] that are plotted as
   the reference series and used as the x grid for the predicted curves, so
   the figure compares the fit with the data it was actually computed from. *)
let run () =
  (* Evaluate a fitted polynomial at the x coordinate of each observation. *)
  let plot_predicted poly =
    List.map (fun (x, _) -> (x, Poly.eval poly x)) observations
  in
  let posterior_means = run_model observations rng_state in
  let predicted =
    posterior_means
    |> Seq.mapi (fun i elt -> (i, elt))
    (* Keep one curve out of ten to keep the figure readable. *)
    |> Seq.filter (fun (i, _) -> i mod 10 = 0)
    |> Seq.map snd |> Seq.map plot_predicted |> List.of_seq
  in
  plot observations predicted ;
  Seq.iteri
    (fun i coeff -> Format.printf "%d, %a@." i Poly.pp coeff)
    posterior_means

(* Test suite: a single QCheck test that runs the whole pipeline once and
   succeeds as long as [run] returns. *)
let tests =
  let fit_once () =
    run () ;
    true
  in
  [QCheck.Test.make ~name:"smc-poly-fit" ~count:1 QCheck.unit fit_once]
OCaml

Innovation. Community. Security.