package frenetic

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file Vlr.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
open Core

(* Signature of values that can be serialized to s-expressions, compared,
   tested for equality, hashed, and rendered as a string. *)
module type HashCmp = sig
  type t [@@deriving sexp, compare, eq, hash]
  (* val pp : Format.formatter -> t -> unit *)
  val to_string : t -> string
end

(* A [HashCmp] type that additionally provides a binary [subset_eq] predicate.
   NOTE(review): presumably [subset_eq a b] holds when [a] is contained in (or
   equal to) [b] -- confirm against the concrete implementations. *)
module type Lattice = sig
  include HashCmp
  val subset_eq : t -> t -> bool
end

(* Signature of the values stored at the leaves of a diagram: a [HashCmp] type
   with two binary operations and two distinguished constants. Within this
   file, [zero] is the leaf value that [Make.sum] treats as neutral and [one]
   is the leaf value that [Make.prod] treats as neutral. *)
module type Result = sig
  include HashCmp
  val sum : t -> t -> t
  val prod : t -> t -> t
  val one : t
  val zero : t
end

(* Pairs of integers usable as hash-table keys. The hash is a fixed linear
   combination of the two components. *)
module IntPair = struct
  type t = int * int [@@deriving sexp, compare]

  let hash ((fst_id, snd_id) : t) : int =
    let weighted_fst = 617 * fst_id in
    let weighted_snd = 619 * snd_id in
    weighted_fst + weighted_snd
end

(* Hash tables keyed by [IntPair.t]. *)
module IntPairTbl = Hashtbl.Make(IntPair)

module Make(V:HashCmp)(L:Lattice)(R:Result) = struct
  (* A test: a variable of type [V.t] paired with a value of type [L.t]. *)
  type v = V.t * L.t [@@deriving sexp, compare, hash]
  (* The value stored at a leaf. *)
  type r = R.t [@@deriving sexp, compare, hash]

  type d
    = Leaf of r
    | Branch of {
        test : v;
        tru : int;
        fls : int;
        all_fls : int [@compare.ignore] (* implies [@hash.ignore] *);
      }
    [@@deriving sexp, compare, hash]
  (* A tree structure representing the decision diagram. The [Leaf] variant
   * represents a constant function. A [Branch { test = (v, l); tru; fls; _ }]
   * represents an if-then-else: when variable [v] takes on the value [l], then
   * the diagram with id [tru] should hold. Otherwise, the diagram with id
   * [fls] should hold. Children are stored as integer ids rather than as
   * nested nodes.
   *
   * [all_fls] is a cached id that is ignored by comparison and hashing. As
   * maintained by [mk_branch], it is the diagram reached once every
   * consecutive test on the same variable [v] has failed: it is inherited
   * from [fls] when [fls] tests the same variable, and is [fls] itself
   * otherwise.
   *
   * [Branch] nodes appear in an order determined first by the total order on
   * the [V.t] value, with ties broken by the total order on [L.t]. The
   * least such pair should appear at the root of the diagram, with each child
   * node being strictly greater than its parent node. This invariant is
   * important both for efficiency and correctness.
   * *)

  (* A diagram is referred to by an integer id (see [get_uid]); the
     comparator generated below is derived from integer comparison on ids. *)
  module T = struct
    type t = int [@@deriving sexp, compare, eq]
  end
  include T
  include Comparator.Make(T)

  (* Hash-consing table for nodes of type [d]. In this file, [D.get] maps a
     node to its integer id and [D.unget] maps an id back to its node. *)
  module D = Frenetic_kernel.Hashcons.Make(struct
    type t = d [@@deriving sexp, compare, hash]
  end)

  (* Aliases for the hash-consing operations: node -> id and id -> node. *)
  let get = D.get
  let unget = D.unget
  (* The id of a diagram is the diagram handle itself. *)
  let get_uid (t:t) : int = t

  (* Hash tables keyed by a single diagram id, and by a pair of diagram ids. *)
  module Tbl = Int.Table
  module BinTbl = IntPairTbl

  (* [mk_leaf r] returns the id of the constant diagram holding [r]. *)
  let mk_leaf r = D.get (Leaf r)

  (* [mk_branch test tru fls] returns the id of a diagram that continues with
     [tru] when [test] succeeds and with [fls] otherwise. Redundant nodes are
     not constructed; in that case the id of [fls] is returned instead.

     The caller is responsible for respecting the variable ordering: no check
     is made that [test] is ordered before the roots of [tru] and [fls].

     Variables are compared with [V.compare], the same order used by the rest
     of this module, rather than with polymorphic equality. *)
  let mk_branch ((v,_) as test) tru fls =
    (* When the ids of the diagrams are equal, then the diagram will take on the
       same value regardless of variable assignment. The node that's being
       constructed can therefore be eliminated and replaced with one of the
       sub-diagrams, which are identical.

       If the ids are distinct, then the node has to be constructed and assigned
       a new id. *)
    if equal tru fls then
      fls
    else match unget fls with
    | Branch { test = (v',_); all_fls; _ } when V.compare v v' = 0 ->
      (* [fls] goes on testing the same variable. If failing all of those
         tests leads to [tru] as well, the new node is dropped in favour of
         [fls]; otherwise the new node inherits [all_fls] from [fls]. *)
      if all_fls = tru then
        fls
      else
        D.get (Branch { test; tru; fls; all_fls })
    | _ ->
      (* [fls] is a leaf or tests a different variable, so it is itself the
         diagram reached when all tests on [v] fail. *)
      D.get (Branch { test; tru; fls; all_fls = fls})

  (* Alias of [mk_branch]: builds a conditional without verifying that the
     test is ordered before the tests in the sub-diagrams. *)
  let unchecked_cond = mk_branch

  (* The constant diagrams holding [R.zero] and [R.one] respectively. *)
  let drop = mk_leaf (R.zero)
  let id = mk_leaf (R.one)

  (* Renders a diagram as a nested conditional expression. The two constant
     diagrams [drop] and [id] are printed as "0" and "1"; any other leaf is
     printed with [R.to_string]. *)
  let rec to_string t =
    if equal t drop then "0"
    else if equal t id then "1"
    else
      match D.unget t with
      | Leaf leaf_value -> R.to_string leaf_value
      | Branch { test = (var, value); tru; fls; _ } ->
        Printf.sprintf "(%s = %s ? %s : %s)"
          (V.to_string var) (L.to_string value) (to_string tru) (to_string fls)


  (* [fold ~f ~g t] collapses the diagram [t] bottom-up: [f] is applied to the
     value of every leaf and [g] combines a node's test with the results
     computed for its true and false sub-diagrams. Shared sub-diagrams are
     re-traversed each time they are reached; nothing is memoized. *)
  let fold ~f ~g t =
    let rec go node =
      match D.unget node with
      | Leaf leaf_value -> f leaf_value
      | Branch { test; tru; fls; _ } -> g test (go tru) (go fls)
    in
    go t

  (* [const r] is the constant diagram holding [r]. *)
  let const = mk_leaf

  (* [atom test t f] is the single-test diagram that yields the constant [t]
     when [test] succeeds and the constant [f] otherwise. *)
  let atom test t f = mk_branch test (const t) (const f)

  (* [map_r ~f t] rebuilds [t] with [f] applied to the value of every leaf.
     Branch nodes are re-created with [mk_branch], so nodes that become
     redundant after the mapping are eliminated. The binding is not recursive
     (the recursion lives in [fold]), hence no [rec]. *)
  let map_r ~f t = fold t
    ~f:(fun r -> const (f r))
    ~g:(fun (v, l) tru fls -> mk_branch (v,l) tru fls)

  (* [restrict lst u] specializes the diagram [u] under the list [lst] of
     (variable, value) constraints. The constraints are first sorted by
     variable with [V.compare], then consumed while walking the diagram.

     At a node testing [(v', l')] with the head constraint [(v, l)]:
     - same variable: if [L.subset_eq l l'] holds, the true branch is taken
       and the constraint is consumed; otherwise the false branch is taken
       with the constraint kept, since that branch may test the variable
       again;
     - [v] ordered before [v']: the constraint is dropped;
     - [v] ordered after [v']: the node is kept and both of its children are
       restricted.

     Only the sign of [V.compare] is inspected, so comparison functions that
     return values other than -1, 0 and 1 are handled correctly. *)
  let restrict lst u =
    let rec loop xs u =
      match xs, D.unget u with
      | []          , _
      | _           , Leaf _ -> u
      | (v,l) :: xs', Branch { test = (v', l'); tru = t; fls = f; _ } ->
        let c = V.compare v v' in
        if c = 0 then
          if L.subset_eq l l' then loop xs' t else loop xs f
        else if c < 0 then
          loop xs' u
        else
          mk_branch (v',l') (loop xs t) (loop xs f)
    in
    loop (List.sort lst ~compare:(fun (u, _) (v, _) -> V.compare u v)) u

  (* [apply f zero ~cache] lifts the binary leaf operation [f] to a binary
     operation on diagrams. Results are memoized in [cache], keyed by the pair
     of operand ids.

     [zero] is treated as neutral for [f]: when one operand is a leaf whose
     value compares equal to [zero], the other operand is returned unchanged.
     Any other leaf operand is combined into every leaf of the other diagram
     with [map_r].

     When both operands are branches, the test that comes first in the
     ordering (by [V.compare], then [L.compare]) becomes the root of the
     result. If both roots test the same variable with different values, the
     true branch of the earlier test is combined with the other operand's
     [all_fls]: the diagram that operand reduces to once all of its tests on
     that variable have failed.

     Only the sign of [V.compare] and [L.compare] is inspected, so comparison
     functions that return values other than -1, 0 and 1 are handled
     correctly. *)
  let apply f zero ~(cache: (t*t, t) Hashtbl.t) =
    let rec sum x y =
      Hashtbl.find_or_add cache (x, y) ~default:(fun () -> sum' x y)
    and sum' x y =
      match D.unget x, D.unget y with
      | Leaf r, _      ->
        if R.compare r zero = 0 then y
        else map_r ~f:(fun y -> f r y) y
      | _     , Leaf r ->
        if R.compare zero r = 0 then x
        else map_r ~f:(fun x -> f x r) x
      | Branch {test=(vx, lx); tru=tx; fls=fx; all_fls=all_fls_x},
        Branch {test=(vy, ly); tru=ty; fls=fy; all_fls=all_fls_y} ->
        let cv = V.compare vx vy in
        if cv < 0 then
          mk_branch (vx,lx) (sum tx y) (sum fx y)
        else if cv > 0 then
          mk_branch (vy,ly) (sum x ty) (sum x fy)
        else begin
          (* Same variable: order the two tests by their values. *)
          let cl = L.compare lx ly in
          if cl = 0 then
            mk_branch (vx,lx) (sum tx ty) (sum fx fy)
          else if cl < 0 then
            mk_branch (vx,lx) (sum tx all_fls_y) (sum fx y)
          else
            mk_branch (vy,ly) (sum all_fls_x ty) (sum x fy)
        end
    in sum

  (* Memoization table for [sum], keyed by the pair of operand ids. Emptied by
     [clear_cache]. *)
  let sum_tbl : (t*t, t) Hashtbl.t = BinTbl.create ~size:1000 ()
  (* Diagram-level [R.sum], with [R.zero] as the neutral leaf value. *)
  let sum = apply R.sum R.zero ~cache:sum_tbl

  (* Memoization table for [prod], keyed by the pair of operand ids. Emptied by
     [clear_cache]. *)
  let prod_tbl : (t*t, t) Hashtbl.t = BinTbl.create ~size:1000 ()
  (* Diagram-level [R.prod], with [R.one] as the neutral leaf value. *)
  let prod = apply R.prod R.one ~cache:prod_tbl

  (* Lists the ids of every node reachable from [t], excluding [t] itself.
     The traversal keeps no visited set, so a sub-diagram reachable along
     several paths is listed once per path. *)
  let childreen t =
    let rec collect node acc =
      match unget node with
      | Leaf _ -> acc
      | Branch { tru; fls; _ } ->
        let with_children = tru :: fls :: acc in
        let with_tru_subtree = collect tru with_children in
        collect fls with_tru_subtree
    in
    collect t []

  (* [clear_cache ~preserve] empties the [sum] and [prod] memoization tables
     and clears the hash-consing table via [D.clear], passing it the set of
     ids to keep: [preserve], the constants [drop] and [id], and everything
     reachable from any of those. *)
  let clear_cache ~(preserve : Int.Set.t) =
    (* SJS: the interface exposes `id` and `drop` as constants,
       so they must NEVER be cleared from the cache *)
    let roots = Int.Set.add (Int.Set.add preserve drop) id in
    let keep =
      Int.Set.fold roots ~init:roots ~f:(fun kept root ->
        List.fold (childreen root) ~init:kept ~f:Int.Set.add)
    in
    Hashtbl.clear sum_tbl;
    Hashtbl.clear prod_tbl;
    D.clear keep

  (* [cond v t f] builds the diagram "if test [v] holds then [t] else [f]".

     When both [t] and [f] are leaves, or are rooted at a variable ordered
     strictly after the variable of [v], the node is built directly with
     [mk_branch]. Otherwise the result is assembled from atoms with [sum] and
     [prod], which place the test according to the ordering.

     Only the sign of [V.compare] is inspected, so a comparison function that
     returns a negative value other than -1 still takes the direct path. *)
  let cond v t f =
    let ok t =
      match unget t with
      | Leaf _ -> true
      | Branch { test = (var, _); _ } -> V.compare (fst v) var < 0
    in
    if equal t f then t else
    if ok t && ok f then mk_branch v t f else
      (sum (prod (atom v R.one R.zero) t)
           (prod (atom v R.zero R.one) f))

  (* [map ~f ~g t] rebuilds a diagram from [t]: [f] turns each leaf value into
     a diagram and [g] combines a node's test with the diagrams obtained for
     its true and false children. No memoization is performed. *)
  let map ~(f : R.t -> t)
          ~(g : V.t * L.t -> t -> t -> t)
           (t : t) : t =
    let rec go node =
      match unget node with
      | Leaf leaf_value -> f leaf_value
      | Branch { test; tru; fls; _ } -> g test (go tru) (go fls)
    in
    go t

  (* Like [map], but every node is looked up through the caller-supplied
     [find_or_add], which is called with the node id and a [~default] thunk
     that computes the node's result. Supplying a memoizing [find_or_add]
     lets shared sub-diagrams be processed once. *)
  let dp_map ~(f : R.t -> t)
             ~(g : V.t * L.t -> t -> t -> t)
             (t : t)
             ~find_or_add
             : t =
    let rec lookup node =
      find_or_add node ~default:(fun () -> compute node)
    and compute node =
      match unget node with
      | Leaf leaf_value -> f leaf_value
      | Branch { test; tru; fls; _ } -> g test (lookup tru) (lookup fls)
    in
    lookup t

  (* Counts the distinct nodes (leaves and branches) reachable from [node],
     counting a shared sub-diagram only once. *)
  let compressed_size (node : t) : int =
    let rec count (visited : Int.Set.t) (current : t) =
      if Int.Set.mem visited current then (0, visited)
      else
        match D.unget current with
        | Leaf _ -> (1, Int.Set.add visited current)
        | Branch { tru; fls; _ } ->
          (* Because of the variable ordering, [current] does not need to be
             marked as visited before descending into its children. *)
          let tru_size, visited = count visited tru in
          let fls_size, visited = count visited fls in
          (1 + tru_size + fls_size, Int.Set.add visited current)
    in
    let total, _ = count Int.Set.empty node in
    total

  (* Counts the nodes of [node] viewed as a tree: a sub-diagram reachable
     along several paths is counted once per path. *)
  let rec uncompressed_size (node : t) : int =
    match D.unget node with
    | Leaf _ -> 1
    | Branch { tru; fls; _ } ->
      let tru_size = uncompressed_size tru in
      let fls_size = uncompressed_size fls in
      1 + tru_size + fls_size

  (* [to_dot t] renders the diagram [t] as a Graphviz "digraph" and returns
     the text. Leaves are drawn as boxes labelled with [R.to_string]; branch
     nodes are labelled "variable = value", with a solid edge to the true
     child and a dashed edge to the false child. Each node is emitted once,
     and nodes that carry the same test are grouped with "rank=same". *)
  let to_dot t =
    let open Format in
    let buf = Buffer.create 200 in
    let fmt = formatter_of_buffer buf in
    (* Ids of the nodes that have already been emitted. *)
    let seen : Int.Hash_set.t = Int.Hash_set.create ~size:10 () in
    (* For every test, the ids of the nodes carrying that test. *)
    let rank : ((V.t*L.t), Int.Hash_set.t) Hashtbl.t = Hashtbl.Poly.create ~size:20 () in
    (* A very large margin keeps the formatter from breaking lines. *)
    pp_set_margin fmt (1 lsl 29);
    fprintf fmt "digraph tdk {@\n";
    let rec loop t =
      if not (Hash_set.mem seen t) then begin
        Hash_set.add seen t;
        match D.unget t with
        | Leaf r ->
          fprintf fmt "%d [shape=box label=\"%s\"];@\n" t (R.to_string r)
        | Branch { test=(v, l); tru=a; fls=b; _ } ->
          (* Fetch the rank group of this test, creating it on first use. *)
          let same_rank =
            Hashtbl.find_or_add rank (v, l)
              ~default:(fun () -> Int.Hash_set.create ~size:10 ())
          in
          Hash_set.add same_rank t;
          fprintf fmt "%d [label=\"%s = %s\"];@\n"
            t (V.to_string v) (L.to_string l);
          fprintf fmt "%d -> %d;@\n" t a;
          fprintf fmt "%d -> %d [style=\"dashed\"];@\n" t b;
          loop a;
          loop b
      end
    in
    loop t;
    Hashtbl.iteri rank ~f:(fun ~key:_ ~data:s ->
      fprintf fmt "{rank=same; ";
      Hash_set.iter s ~f:(fun x -> fprintf fmt "%d " x);
      fprintf fmt ";}@\n");
    (* "@." flushes the formatter so the buffer holds the complete output. *)
    fprintf fmt "}@.";
    Buffer.contents buf

  (* [render ?format ?title t] converts [t] to Graphviz text with [to_dot] and
     hands it to [Frenetic_kernel.Util.show_dot]. [format] defaults to "pdf"
     and [title] to "FDD". *)
  let render ?(format="pdf") ?(title="FDD") t =
    Frenetic_kernel.Util.show_dot ~format ~title (to_dot t)

  (* [refs t] is the set of ids of all nodes reachable from [t], including
     [t] itself. *)
  let refs (t : t) : Int.Set.t =
    let rec visit (seen : Int.Set.t) (node : t) =
      if Int.Set.mem seen node then seen
      else
        match D.unget node with
        | Leaf _ -> Int.Set.add seen node
        | Branch { tru; fls; _ } ->
          let after_tru = visit seen tru in
          let after_fls = visit after_tru fls in
          Int.Set.add after_fls node
    in
    visit Int.Set.empty t

  (* Converts a node into a self-contained s-expression: children are
     expanded inline instead of being referenced by id, and the cached
     [all_fls] field is not written. *)
  let rec node_to_sexp node =
    let open Sexplib.Sexp in
    match node with
    | Branch { test; tru; fls; _ } ->
      let tru_sexp = node_to_sexp (unget tru) in
      let fls_sexp = node_to_sexp (unget fls) in
      List [Atom "Branch"; sexp_of_v test; tru_sexp; fls_sexp]
    | Leaf leaf_value ->
      List [Atom "Leaf"; R.sexp_of_t leaf_value]

  (* Rebuilds a diagram from the s-expression format produced by
     [node_to_sexp] and returns its id. Branches are re-created with
     [mk_branch]. Raises [Failure] on an s-expression of any other shape. *)
  let rec node_of_sexp sexp =
    let open Sexplib.Sexp in
    match sexp with
    | List [Atom "Branch"; test_sexp; tru_sexp; fls_sexp] ->
      let test = v_of_sexp test_sexp in
      let tru = node_of_sexp tru_sexp in
      let fls = node_of_sexp fls_sexp in
      mk_branch test tru fls
    | List [Atom "Leaf"; payload] ->
      mk_leaf (R.t_of_sexp payload)
    | _ ->
      failwith "unsexpected s-expression!"


  (* [serialize t] is the s-expression text of the diagram [t], as produced
     by [node_to_sexp]. *)
  let serialize (t : t) : string =
    let root = unget t in
    Sexp.to_string (node_to_sexp root)

  (* [deserialize s] parses [s] as an s-expression and rebuilds the diagram
     it describes with [node_of_sexp]. *)
  let deserialize (s : string) : t =
    let sexp = Sexp.of_string s in
    node_of_sexp sexp
end
OCaml

Innovation. Community. Security.