package prbnmcn-dagger-gsl

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file gsl_dist.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
module Log_space = Dagger.Log_space

(** Interface that the functor [Make] below expects from its GSL argument:
    a random number generator module [Rng] and a module [Randist] of
    samplers paired with probability (density) functions.

    NOTE(review): the names presumably mirror the [Gsl.Rng] and
    [Gsl.Randist] modules of the OCaml GSL bindings -- confirm. *)
module type GSL_SIG = sig
  (** Random number generators. *)
  module Rng : sig
    (** Kind of generator, passed to [make]. *)
    type rng_type

    (** A generator instance. *)
    type t

    (** Default generator kind; [Make] stores its result in [gsl_rng]. *)
    val default : unit -> rng_type

    (** Creates a generator of the given kind. *)
    val make : rng_type -> t

    (** Used by [Make.rng] to seed a freshly made generator. *)
    val set : t -> nativeint -> unit

    (** Integer sampler used by [Make.int].
        NOTE(review): presumably uniform on [0 .. n-1], which is the range
        [Make.int] assigns non-zero density to -- confirm. *)
    val uniform_int : t -> int -> int
  end

  (** Samplers and their matching [_pdf] functions. Samplers take the
      generator as first argument; most [_pdf] functions take the evaluation
      point as first, unlabelled argument. The meaning of each labelled
      parameter is defined by the underlying binding, not by this file. *)
  module Randist : sig
    val flat : Rng.t -> a:float -> b:float -> float

    val flat_pdf : float -> a:float -> b:float -> float

    (** Returns an [int]; [Make.bernoulli] compares the result to [1]. *)
    val bernoulli : Rng.t -> p:float -> int

    val bernoulli_pdf : int -> p:float -> float

    (** Takes no location parameter; [Make.gaussian] adds its [mean] to the
        sample itself. *)
    val gaussian : Rng.t -> sigma:float -> float

    val gaussian_pdf : float -> sigma:float -> float

    val gaussian_tail : Rng.t -> a:float -> sigma:float -> float

    val gaussian_tail_pdf : float -> a:float -> sigma:float -> float

    val laplace : Rng.t -> a:float -> float

    val laplace_pdf : float -> a:float -> float

    val exppow : Rng.t -> a:float -> b:float -> float

    val exppow_pdf : float -> a:float -> b:float -> float

    val cauchy : Rng.t -> a:float -> float

    val cauchy_pdf : float -> a:float -> float

    val rayleigh : Rng.t -> sigma:float -> float

    val rayleigh_pdf : float -> sigma:float -> float

    val rayleigh_tail : Rng.t -> a:float -> sigma:float -> float

    val rayleigh_tail_pdf : float -> a:float -> sigma:float -> float

    val landau : Rng.t -> float

    val landau_pdf : float -> float

    val gamma : Rng.t -> a:float -> b:float -> float

    val gamma_pdf : float -> a:float -> b:float -> float

    val weibull : Rng.t -> a:float -> b:float -> float

    val weibull_pdf : float -> a:float -> b:float -> float

    val binomial : Rng.t -> p:float -> n:int -> int

    val binomial_pdf : int -> p:float -> n:int -> float

    val geometric : Rng.t -> p:float -> int

    val geometric_pdf : int -> p:float -> float

    val exponential : Rng.t -> mu:float -> float

    val exponential_pdf : float -> mu:float -> float

    val poisson : Rng.t -> mu:float -> int

    val poisson_pdf : int -> mu:float -> float

    (** Preprocessed table for sampling an index; built by
        [discrete_preproc] from a float array. *)
    type discrete

    val discrete_preproc : float array -> discrete

    (** Samples an index; [Make] uses it to index the array the table was
        built from. *)
    val discrete : Rng.t -> discrete -> int

    (** Probability associated with an index of the table. *)
    val discrete_pdf : int -> discrete -> float

    val beta : Rng.t -> a:float -> b:float -> float

    val beta_pdf : float -> a:float -> b:float -> float

    (** Returns [unit]: [Make.dirichlet] allocates [theta] beforehand and
        relies on the sample being written into it. *)
    val dirichlet : Rng.t -> alpha:float array -> theta:float array -> unit

    val dirichlet_pdf : alpha:float array -> theta:float array -> float

    (** [Make.dirichlet] passes this result to [Log_space.unsafe_cast], i.e.
        it is treated as already being in log-space. *)
    val dirichlet_lnpdf : alpha:float array -> theta:float array -> float

    val lognormal : Rng.t -> zeta:float -> sigma:float -> float

    val lognormal_pdf : float -> zeta:float -> sigma:float -> float

    val chisq : Rng.t -> nu:float -> float

    val chisq_pdf : float -> nu:float -> float
  end
end

module Make (Gsl : GSL_SIG) = struct
  open Gsl

  (** Kind of GSL generator instantiated by [rng]. Held in a reference so it
      can be reassigned after the functor has been applied. *)
  let gsl_rng = ref (Rng.default ())

  (** [rng s] makes a new GSL generator of kind [!gsl_rng] and seeds it with
      native bits drawn from the OCaml random state [s]. A fresh generator is
      created on every call. *)
  let rng (s : Random.State.t) =
    let generator = Rng.make !gsl_rng in
    Rng.set generator (Random.State.nativebits s) ;
    generator

  open Dagger.Dist

  (** [dist0 sampler log_pdf] builds a stateless distribution from a
      parameterless GSL [sampler]. The sampler receives a GSL generator
      obtained from the OCaml random state through [rng]. *)
  let dist0 sampler log_pdf =
    let draw rng_state = sampler (rng rng_state) in
    stateless draw log_pdf
    [@@inline]

  (** [dist1 sampler log_pdf arg] builds a stateless distribution with one
      parameter [arg], which is passed as first argument to both [sampler]
      and [log_pdf]. The sampler's last argument is a GSL generator obtained
      through [rng]. *)
  let dist1 sampler log_pdf arg =
    let draw rng_state = sampler arg (rng rng_state) in
    let density x = log_pdf arg x in
    stateless draw density
    [@@inline]

  (** [dist2 sampler log_pdf arg1 arg2] is the two-parameter counterpart of
      [dist1]: both parameters are passed, in order, to [sampler] and
      [log_pdf]. *)
  let dist2 sampler log_pdf arg1 arg2 =
    let draw rng_state = sampler arg1 arg2 (rng rng_state) in
    let density x = log_pdf arg1 arg2 x in
    stateless draw density
    [@@inline]

  (** [kernel1 sampler log_pdf start arg] builds a kernel starting at [start]
      with one parameter [arg]. [sampler arg x g] proposes the successor of
      [x] using the GSL generator [g]; [log_pdf arg x y] scores the move from
      [x] to [y]. *)
  let kernel1 sampler log_pdf start arg =
    let step x rng_state = sampler arg x (rng rng_state) in
    let transition_ll x y = log_pdf arg x y in
    kernel start step transition_ll
    [@@inline]

  (** [float bound] samples with [Randist.flat ~a:0.0 ~b:bound]; its
      log-density is [Randist.flat_pdf] with the same bounds, converted with
      [Log_space.of_float]. *)
  let float bound =
    let draw upper gen = Randist.flat gen ~a:0.0 ~b:upper in
    let log_density upper x =
      Log_space.of_float (Randist.flat_pdf x ~a:0.0 ~b:upper)
    in
    dist1 draw log_density bound

  (** [int bound] samples with [Rng.uniform_int _ bound]. The log-density is
      that of [1 / bound] for integers [k] with [0 <= k < bound] and
      [Log_space.zero] for every other integer. *)
  let int bound =
    let draw n gen = Rng.uniform_int gen n in
    let log_density n k =
      if 0 <= k && k < n then Log_space.of_float (1. /. float_of_int n)
      else Log_space.zero
    in
    dist1 draw log_density bound

  (** Fair coin: samples with [Randist.bernoulli ~p:0.5]; every outcome gets
      the log-density of [0.5].

      NOTE(review): the sampler yields the [int] returned by
      [Randist.bernoulli], so despite its name this is a distribution over
      integers -- confirm whether a comparison to [1], as in [bernoulli], was
      intended. *)
  let bool =
    let half = Log_space.of_float 0.5 in
    let draw gen = Randist.bernoulli gen ~p:0.5 in
    dist0 draw (fun _ -> half)

  (** [gaussian ~mean ~std] is the Gaussian distribution with the given mean
      and standard deviation.

      [Randist.gaussian] and [Randist.gaussian_pdf] take no location
      parameter: the sampler draws centred noise and shifts it by [mean], so
      the density at [x] is the centred density evaluated at [x -. mean]. *)
  let gaussian ~mean ~std =
    dist2
      (fun mean sigma state -> mean +. Randist.gaussian state ~sigma)
      (fun mean sigma x ->
        (* Undo the shift applied by the sampler before evaluating the
           centred density (evaluating at [mean +. x] would mirror the
           density around [-mean]). *)
        Log_space.of_float (Randist.gaussian_pdf (x -. mean) ~sigma))
      mean
      std

  (** [gaussian_tail ~a ~std] samples with [Randist.gaussian_tail ~a
      ~sigma:std]; its log-density is [Randist.gaussian_tail_pdf] with the
      same parameters, converted with [Log_space.of_float]. *)
  let gaussian_tail ~a ~std =
    let draw a sigma gen = Randist.gaussian_tail gen ~a ~sigma in
    let log_density a sigma x =
      Log_space.of_float (Randist.gaussian_tail_pdf x ~a ~sigma)
    in
    dist2 draw log_density a std

  (** [laplace ~a] samples with [Randist.laplace ~a]; its log-density is
      [Randist.laplace_pdf ~a] converted with [Log_space.of_float]. *)
  let laplace ~a =
    let draw a gen = Randist.laplace gen ~a in
    let log_density a x = Log_space.of_float (Randist.laplace_pdf x ~a) in
    dist1 draw log_density a

  (** [exppow ~a ~b] samples with [Randist.exppow ~a ~b]; its log-density is
      [Randist.exppow_pdf ~a ~b] converted with [Log_space.of_float]. *)
  let exppow ~a ~b =
    let draw a b gen = Randist.exppow gen ~a ~b in
    let log_density a b x = Log_space.of_float (Randist.exppow_pdf x ~a ~b) in
    dist2 draw log_density a b

  (** [cauchy ~a] samples with [Randist.cauchy ~a]; its log-density is
      [Randist.cauchy_pdf ~a] converted with [Log_space.of_float]. *)
  let cauchy ~a =
    let draw a gen = Randist.cauchy gen ~a in
    let log_density a x = Log_space.of_float (Randist.cauchy_pdf x ~a) in
    dist1 draw log_density a

  (** [rayleigh ~sigma] samples with [Randist.rayleigh ~sigma]; its
      log-density is [Randist.rayleigh_pdf ~sigma] converted with
      [Log_space.of_float]. *)
  let rayleigh ~sigma =
    let draw sigma gen = Randist.rayleigh gen ~sigma in
    let log_density sigma x =
      Log_space.of_float (Randist.rayleigh_pdf x ~sigma)
    in
    dist1 draw log_density sigma

  (** [rayleigh_tail ~a ~sigma] samples with [Randist.rayleigh_tail ~a
      ~sigma]; its log-density is [Randist.rayleigh_tail_pdf] with the same
      parameters, converted with [Log_space.of_float]. *)
  let rayleigh_tail ~a ~sigma =
    let draw a sigma gen = Randist.rayleigh_tail gen ~a ~sigma in
    let log_density a sigma x =
      Log_space.of_float (Randist.rayleigh_tail_pdf x ~a ~sigma)
    in
    dist2 draw log_density a sigma

  (** Parameterless distribution sampled with [Randist.landau]; its
      log-density is [Randist.landau_pdf] converted with
      [Log_space.of_float]. *)
  let landau =
    let log_density x = Log_space.of_float (Randist.landau_pdf x) in
    dist0 Randist.landau log_density

  (** [gamma ~a ~b] samples with [Randist.gamma ~a ~b]; its log-density is
      [Randist.gamma_pdf ~a ~b] converted with [Log_space.of_float]. *)
  let gamma ~a ~b =
    let draw a b gen = Randist.gamma gen ~a ~b in
    let log_density a b x = Log_space.of_float (Randist.gamma_pdf x ~a ~b) in
    dist2 draw log_density a b

  (** [weibull ~a ~b] samples with [Randist.weibull ~a ~b]; its log-density
      is [Randist.weibull_pdf ~a ~b] converted with [Log_space.of_float]. *)
  let weibull ~a ~b =
    let draw a b gen = Randist.weibull gen ~a ~b in
    let log_density a b x =
      Log_space.of_float (Randist.weibull_pdf x ~a ~b)
    in
    dist2 draw log_density a b

  (** [flat a b] samples with [Randist.flat ~a ~b]; its log-density is
      [Randist.flat_pdf ~a ~b] converted with [Log_space.of_float]. Both
      bounds are positional arguments here. *)
  let flat a b =
    let draw lo hi gen = Randist.flat gen ~a:lo ~b:hi in
    let log_density lo hi x =
      Log_space.of_float (Randist.flat_pdf x ~a:lo ~b:hi)
    in
    dist2 draw log_density a b

  (** [bernoulli ~bias] is a distribution over booleans: a sample is [true]
      exactly when [Randist.bernoulli ~p:bias] returns [1]. [true] gets the
      log-density of [bias] and [false] that of [1. -. bias]. *)
  let bernoulli ~bias =
    let draw p gen =
      match Randist.bernoulli gen ~p with 1 -> true | _ -> false
    in
    let log_density p outcome =
      Log_space.of_float (if outcome then p else 1. -. p)
    in
    dist1 draw log_density bias

  (** [binomial p n] samples with [Randist.binomial ~p ~n]; its log-density
      is [Randist.binomial_pdf ~p ~n] converted with [Log_space.of_float].
      Both parameters are positional arguments here. *)
  let binomial p n =
    let draw p n gen = Randist.binomial gen ~p ~n in
    let log_density p n k =
      Log_space.of_float (Randist.binomial_pdf k ~p ~n)
    in
    dist2 draw log_density p n

  (** [geometric ~p] samples with [Randist.geometric ~p]; its log-density is
      [Randist.geometric_pdf ~p] converted with [Log_space.of_float]. *)
  let geometric ~p =
    let draw p gen = Randist.geometric gen ~p in
    let log_density p k = Log_space.of_float (Randist.geometric_pdf k ~p) in
    dist1 draw log_density p

  (** [exponential ~rate] is the exponential distribution with the given
      rate, i.e. with mean [1. /. rate].

      GSL parameterises the exponential distribution by its mean [mu]
      (density [exp (-x / mu) / mu]), so the rate is inverted before being
      handed to [Randist.exponential] and [Randist.exponential_pdf];
      forwarding [rate] unchanged would yield a distribution of mean [rate]
      instead. *)
  let exponential ~rate =
    dist1
      (fun mu state -> Randist.exponential state ~mu)
      (fun mu x -> Log_space.of_float (Randist.exponential_pdf x ~mu))
      (* mean = 1 / rate *)
      (1. /. rate)

  (** [poisson ~rate] samples with [Randist.poisson ~mu:rate]; its
      log-density is [Randist.poisson_pdf ~mu:rate] converted with
      [Log_space.of_float]. *)
  let poisson ~rate =
    let draw mu gen = Randist.poisson gen ~mu in
    let log_density mu k = Log_space.of_float (Randist.poisson_pdf k ~mu) in
    dist1 draw log_density rate

  (** [categorical (module H) cases] is a distribution over the first
      components of [cases], sampled through a GSL discrete table built from
      the second components.

      [H] provides hashing and equality on elements; it is used to map an
      element back to its index in [cases] when evaluating the log-density.
      An element that does not occur in [cases] lies outside the support and
      gets [Log_space.zero] rather than aborting the program.

      NOTE(review): if an element occurs several times in [cases], the table
      presumably keeps only its last index, so the reported density is that
      of the last occurrence -- confirm whether duplicates must be
      supported. *)
  let categorical (type a) (module H : Hashtbl.S with type key = a)
      (cases : (a * float) array) =
    let xs = Array.map fst cases in
    let ps = Array.map snd cases in
    (* element -> index in [xs] *)
    let contents = Array.mapi (fun i x -> (x, i)) xs in
    let table = H.of_seq (Array.to_seq contents) in
    let sampler = Randist.discrete_preproc ps in
    dist0
      (fun state ->
        let index = Randist.discrete state sampler in
        xs.(index))
      (fun elt ->
        match H.find_opt table elt with
        | None -> Log_space.zero
        | Some i -> Log_space.of_float (Randist.discrete_pdf i sampler))

  (** [beta ~a ~b] samples with [Randist.beta ~a ~b]; its log-density is
      [Randist.beta_pdf ~a ~b] converted with [Log_space.of_float]. *)
  let beta ~a ~b =
    let draw a b gen = Randist.beta gen ~a ~b in
    let log_density a b x = Log_space.of_float (Randist.beta_pdf x ~a ~b) in
    dist2 draw log_density a b

  (** [dirichlet ~alpha] is a distribution over float arrays of the same
      length as [alpha]. Each sample is a freshly allocated array filled in
      place by [Randist.dirichlet]. The log-density is taken directly from
      [Randist.dirichlet_lnpdf] and cast with [Log_space.unsafe_cast]. *)
  let dirichlet ~alpha =
    let draw alpha gen =
      let theta = Array.make (Array.length alpha) 0.0 in
      Randist.dirichlet gen ~alpha ~theta ;
      theta
    in
    let log_density alpha theta =
      Log_space.unsafe_cast (Randist.dirichlet_lnpdf ~alpha ~theta)
    in
    dist1 draw log_density alpha

  (** [lognormal ~zeta ~sigma] samples with [Randist.lognormal ~zeta ~sigma];
      its log-density is [Randist.lognormal_pdf] with the same parameters,
      converted with [Log_space.of_float]. *)
  let lognormal ~zeta ~sigma =
    let draw zeta sigma gen = Randist.lognormal gen ~zeta ~sigma in
    let log_density zeta sigma x =
      Log_space.of_float (Randist.lognormal_pdf x ~zeta ~sigma)
    in
    dist2 draw log_density zeta sigma

  (** [chi_squared ~nu] samples with [Randist.chisq ~nu]; its log-density is
      [Randist.chisq_pdf ~nu] converted with [Log_space.of_float]. *)
  let chi_squared ~nu =
    let draw nu gen = Randist.chisq gen ~nu in
    let log_density nu x = Log_space.of_float (Randist.chisq_pdf x ~nu) in
    dist1 draw log_density nu

  (** [mixture coeffs dists] is the mixture of the stateless distributions
      [dists], component [i] being selected with a probability proportional
      to [coeffs.(i)].

      Sampling picks a component with a GSL discrete table built from
      [coeffs], then samples from it. The density at [x] is the weighted
      {e sum} of the component densities at [x], with the weights read back
      from the same table so that density and sampler agree.

      @raise Invalid_argument if [dists] contains a kernel, if [coeffs] and
      [dists] differ in length, or if they are empty. *)
  let mixture coeffs (dists : 'a t array) =
    let dists =
      Array.map
        (function Stateless d -> d | Kernel _ -> invalid_arg "mixture")
        dists
    in
    if Array.length coeffs <> Array.length dists then invalid_arg "mixture" ;
    if Array.length coeffs = 0 then invalid_arg "mixture" ;
    let sampler = Randist.discrete_preproc coeffs in
    (* Probability with which the sampler selects each component. *)
    let weights =
      Array.init (Array.length coeffs) (fun i ->
          Randist.discrete_pdf i sampler)
    in
    let log_pdf x =
      (* The sum is accumulated in linear space: a mixture density is
         sum_i w_i * p_i(x), not the product of those terms. Densities too
         small to be represented as floats underflow to zero here. *)
      let density = ref 0.0 in
      for i = 0 to Array.length dists - 1 do
        density :=
          !density +. (weights.(i) *. Log_space.to_float (dists.(i).ll x))
      done ;
      Log_space.of_float !density
    in
    let sampler rng_state =
      let case = Randist.discrete (rng rng_state) sampler in
      dists.(case).sample rng_state
    in
    stateless sampler log_pdf
end
OCaml

Innovation. Community. Security.