package catala

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file map.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
(* This file is part of the Catala compiler, a specification language for tax
   and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
   Louis Gesbert <louis.gesbert@inria.fr>

   Licensed under the Apache License, Version 2.0 (the "License"); you may not
   use this file except in compliance with the License. You may obtain a copy of
   the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
   WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
   License for the specific language governing permissions and limitations under
   the License. *)

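(** Thin wrapper around the stdlib [Map] module that adds a few convenience
    functions: listing keys and values, building from a list, disjoint union,
    pretty-printing, and a [Not_found] exception that carries the missing key. *)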

(* NOTE: everything here is either a module type or a module with an explicit
   signature, so a .mli would be completely redundant. *)

(* Re-export the whole stdlib [Map] module (including its [OrderedType], [S]
   and [Make]), so this module can be used as a drop-in replacement. The
   definitions below shadow [OrderedType], [S] and [Make] with extended
   versions. *)
include Stdlib.Map

(** Key type accepted by {!Make}: the stdlib [Map.OrderedType] (a type [t]
    with a [compare] function) extended with a pretty-printer for keys. *)
module type OrderedType = sig
  include Stdlib.Map.OrderedType

  val format : Format.formatter -> t -> unit
  (** Pretty-printer for keys. {!Make} uses it in the message of its
      [Not_found] exception, in the [disjoint_union] failure message, and in
      its [format*] functions. *)
end

(** Signature of the maps built by {!Make}: all of [Stdlib.Map.S], with [find]
    re-declared to raise a more informative exception, plus convenience
    functions for listing, building, merging and formatting maps. *)
module type S = sig
  include Stdlib.Map.S

  exception Not_found of key
  (** Slightly more informative [Not_found] exception: it carries the key whose
      lookup failed. Distinct from [Stdlib.Not_found]. *)

  val find : key -> 'a t -> 'a
  (** [find k m] returns the value bound to [k] in [m].
      @raise Not_found [k] (this module's exception) if [k] is not bound. *)

  val keys : 'a t -> key list
  (** All keys of the map, in increasing order. *)

  val values : 'a t -> 'a list
  (** All values of the map, listed in increasing order of their keys (equal
      values bound to different keys each appear). *)

  val of_list : (key * 'a) list -> 'a t
  (** Builds a map from an association list. When a key occurs several times,
      the last binding in the list wins. *)

  val disjoint_union : 'a t -> 'a t -> 'a t
  (** Union of two maps that are expected to have no key in common.
      @raise Failure if some key is bound in both maps. *)

  val format_keys :
    ?pp_sep:(Format.formatter -> unit -> unit) ->
    Format.formatter ->
    'a t ->
    unit
  (** Formats the keys in increasing order with the key printer of the
      {!OrderedType}. [pp_sep] defaults to that of [Format.pp_print_list] (a
      cut). *)

  val format_values :
    ?pp_sep:(Format.formatter -> unit -> unit) ->
    (Format.formatter -> 'a -> unit) ->
    Format.formatter ->
    'a t ->
    unit
  (** Formats the values, in increasing order of their keys, with the given
      printer. [pp_sep] defaults to that of [Format.pp_print_list] (a cut). *)

  val format_bindings :
    ?pp_sep:(Format.formatter -> unit -> unit) ->
    (Format.formatter -> (Format.formatter -> unit) -> 'a -> unit) ->
    Format.formatter ->
    'a t ->
    unit
  (** Formats the bindings of [t] in order. The user-supplied format function is
      provided with a formatter for keys (can be used with ["%t"]) and the
      corresponding value. *)

  val format :
    ?pp_sep:(Format.formatter -> unit -> unit) ->
    ?pp_bind:(Format.formatter -> unit -> unit) ->
    (Format.formatter -> 'a -> unit) ->
    Format.formatter ->
    'a t ->
    unit
  (** Formats all bindings of the map in increasing key order. [pp_sep] goes
      between bindings (default [";@ "], i.e. [";"] followed by a break hint).
      [pp_bind] goes between each key and its value (default [" =@ "], i.e.
      [" ="] followed by a break hint). In {!Make}, each binding is printed in
      its own [hv] box with indentation 2. *)
end

(** [Make (Ord)] is [Stdlib.Map.Make (Ord)] extended with the functions of
    {!S}. [find] is replaced by a version whose failure carries the missing
    key, and [Ord.format] is used wherever keys are printed. *)
module Make (Ord : OrderedType) : S with type key = Ord.t = struct
  include Stdlib.Map.Make (Ord)

  exception Not_found of key

  (* Render [Not_found k] with the offending key when printed through
     [Printexc] (e.g. for an uncaught exception). Every application of the
     functor defines a fresh exception and registers its own printer. *)
  let () =
    let print_not_found = function
      | Not_found missing ->
        Some (Format.asprintf "key '%a' not found in map" Ord.format missing)
      | _ -> None
    in
    Printexc.register_printer print_not_found

  (* Same as the stdlib [find] (the non-recursive [let] makes the inner call
     refer to it), but a failed lookup reports the missing key. *)
  let find k m =
    match find k m with
    | v -> v
    | exception Stdlib.Not_found -> raise (Not_found k)

  (* [fold] visits bindings in increasing key order, so consing onto the
     accumulator yields them reversed; a single [List.rev] restores the
     order. Not exported: the [S] signature hides it. *)
  let collect_in_order f m =
    List.rev (fold (fun k v acc -> f k v :: acc) m [])

  let keys m = collect_in_order (fun k _ -> k) m
  let values m = collect_in_order (fun _ v -> v) m

  (* Bindings are added left to right, so a later occurrence of a key
     replaces an earlier one. *)
  let of_list pairs =
    List.fold_left (fun acc (k, v) -> add k v acc) empty pairs

  (* Any key bound in both maps aborts the union with [Failure]. *)
  let disjoint_union m1 m2 =
    let on_conflict k _ _ =
      Format.kasprintf failwith "Maps are not disjoint: conflict on key %a"
        Ord.format k
    in
    union on_conflict m1 m2

  let format_keys ?pp_sep ppf m =
    keys m |> Format.pp_print_list ?pp_sep Ord.format ppf

  let format_values ?pp_sep pp_value ppf m =
    values m |> Format.pp_print_list ?pp_sep pp_value ppf

  (* The user printer receives a closure printing the key, usable with
     ["%t"], followed by the value. *)
  let format_bindings ?pp_sep pp_bnd ppf m =
    let pp_one fmt (k, v) = pp_bnd fmt (fun kfmt -> Ord.format kfmt k) v in
    Format.pp_print_list ?pp_sep pp_one ppf (bindings m)

  let format
      ?(pp_sep = fun ppf () -> Format.fprintf ppf ";@ ")
      ?(pp_bind = fun ppf () -> Format.fprintf ppf " =@ ")
      pp_value
      ppf
      m =
    (* One [hv] box (indentation 2) per binding: key, indicator, value. *)
    let pp_one fmt (key, value) =
      Format.fprintf fmt "@[<hv 2>%a%a%a@]" Ord.format key pp_bind ()
        pp_value value
    in
    Format.pp_print_list ~pp_sep pp_one ppf (bindings m)
end
OCaml

Innovation. Community. Security.