package ppx_enumerate

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file ppx_enumerate.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
open Base
open Ppxlib
open Ast_builder.Default

(* Name of the generated enumeration value for a type called [type_name]:
   ["all"] when the type is named ["t"], ["all_of_<type_name>"] otherwise. *)
let name_of_type_name type_name =
  if String.equal type_name "t" then "all" else "all_of_" ^ type_name
(* Name used for the enumeration associated with the type variable [str]:
   the result of [name_of_type_name str] prefixed with an underscore. *)
let name_of_type_variable str =
  let base_name = name_of_type_name str in
  "_" ^ base_name

(* Utility functions *)
(* Type of the enumeration value generated for the type declaration [td].

   The result type is [<td's own type> list]. Each type parameter [p] of [td]
   contributes one leading argument of type [p list], in declaration order, so
   a declaration with parameters [p1 ... pn] yields
   [p1 list -> ... -> pn list -> <td's own type> list]. *)
let enumeration_type_of_td td =
  let result_type =
    let self = core_type_of_type_declaration td in
    let loc = self.ptyp_loc in
    [%type: [%t self] list]
  in
  (* Adds one [param list ->] arrow in front of the type built so far. *)
  let prepend_param (param, _variance) rest =
    let loc = param.ptyp_loc in
    [%type: [%t param] list -> [%t rest]]
  in
  List.fold_right td.ptype_params ~init:result_type ~f:prepend_param
;;

(* Signature item [val <name> : <enumeration type>] for the type declaration
   [td], where [<name>] comes from [name_of_type_name] and the type from
   [enumeration_type_of_td]. Type parameters are named first via
   [name_type_params_in_td]. *)
let sig_of_td td =
  let td = name_type_params_in_td td in
  let loc = td.ptype_loc in
  let value_name = Located.mk ~loc (name_of_type_name td.ptype_name.txt) in
  let type_ = enumeration_type_of_td td in
  value_description ~loc ~name:value_name ~type_ ~prim:[]
  |> psig_value ~loc

(* Signature generator for [@@deriving enumerate] in interfaces.

   When [mk_named_sig] can produce an include of
   [Ppx_enumerate_lib.Enumerable.S] for [tds], that single include is emitted;
   otherwise one [val] item per type declaration is generated with
   [sig_of_td]. *)
let sig_of_tds ~loc ~path:_ (_rec_flag, tds) =
  let named_include =
    mk_named_sig
      tds
      ~loc
      ~sg_name:"Ppx_enumerate_lib.Enumerable.S"
      ~handle_polymorphic_variant:true
  in
  match named_include with
  | None -> List.map tds ~f:sig_of_td
  | Some include_infos -> [ psig_include ~loc include_infos ]

(* Fresh-name generator for this deriver: the [gen_symbol] already in scope,
   with its prefix fixed to ["enumerate"]. *)
let gen_symbol () = gen_symbol ~prefix:"enumerate" ()

(* Tuple expression made of [exprs]. Asserts that there are at least two
   components. *)
let tuple loc exprs =
  let arity = List.length exprs in
  assert (arity >= 2);
  pexp_tuple ~loc exprs
(* Tuple pattern made of [pats]. Asserts that there are at least two
   components. *)
let patt_tuple loc pats =
  let arity = List.length pats in
  assert (arity >= 2);
  ppat_tuple ~loc pats
(* Application of the expression [fn] to [args], located at [fn]'s own
   location. *)
let apply fn args =
  let loc = fn.pexp_loc in
  eapply ~loc fn args

(* Rewrites a core type so that every type variable node ([Ptyp_var]) becomes
   a wildcard ([Ptyp_any]); all other nodes are rebuilt by the default
   traversal. *)
let replace_variables_by_underscores =
  let eraser =
    object
      inherit Ast_traverse.map as super

      method! core_type_desc desc =
        (* Traverse the children first, then rewrite the node itself. *)
        match super#core_type_desc desc with
        | Ptyp_var _ -> Ptyp_any
        | mapped -> mapped
    end
  in
  fun ty -> eraser#core_type ty
;;

(* [list_map loc l ~f] returns code that maps over the list expression [l].

   [f] is called once, at generation time, with an expression standing for the
   current element (a fresh variable); the expression it returns becomes the
   body applied to each element. The generated code is a local accumulating
   loop that reverses its accumulator at the end with
   [Ppx_enumerate_lib.List.rev].

   Note that the generated loop binds the fixed names [map], [l] and [acc],
   and the expression returned by [f] is spliced inside their scope. *)
let list_map loc l ~f =
  let elt_name = gen_symbol () in
  let mapped_elt = f (evar ~loc elt_name) in
  let elt_pat = pvar ~loc elt_name in
  [%expr
    let rec map l acc =
      match l with
      | [] -> Ppx_enumerate_lib.List.rev acc
      | [%p elt_pat] :: l -> map l ([%e mapped_elt] :: acc)
    in
    map [%e l] []]

(* [cartesian_product_map ~exhaust_check l's ~f loc] takes a list [l's] of expressions of
   type list, and returns code generating the Cartesian product of those lists, with [f]
   applied to each tuple. [f] is called at generation time with one expression per input
   list, in the same order as [l's].

   - An empty [l's] raises a located error.
   - A single list is delegated to [list_map].
   - For n >= 2 lists l1 ... ln the generated code has this shape (a1 ... an, x1 ... xn,
     h1 ... hn and tl are fresh names from [gen_symbol]; [loop] and [acc] are fixed names):

       let a1 = l1 in ... let an = ln in
       let rec loop acc = fun x1 ... xn ->
         match (x1, ..., xn) with
         | (_, ..., _, []) -> Ppx_enumerate_lib.List.rev acc
         | (h1 :: tl, h2 :: _, ..., hn :: _) ->
           loop (<f [h1; ...; hn]> :: acc) tl x2 ... xn
         | ([], _ :: tl, _, ..., _) -> loop acc a1 tl x3 ... xn
         | (_, [], _ :: tl, _, ..., _) -> loop acc a1 a2 tl x4 ... xn
         | ...
       in
       loop [] a1 ... an

     so the first list varies fastest, and the loop stops once the last list is exhausted.

   When [exhaust_check] is false, a final [_ -> assert false] case is appended and the
   match carries an [ocaml.warning "-11"] attribute (warning 11 is the unused match case
   warning). When it is true, the match is emitted without either.
*)
let cartesian_product_map ~exhaust_check l's ~f loc =
  match l's with
  | [] -> Location.raise_errorf ~loc "cartesian_product_map passed list of zero length"
  | [l] -> list_map loc l ~f:(fun x -> f [x])
  | _ ->
    let lid x =  evar ~loc x in
    let patt_lid x = pvar ~loc x in
    (* One name per input list, bound to the whole list so that a position can be reset to
       its full list when an earlier position has been exhausted. *)
    let alias_vars = List.map l's ~f:(fun _ -> gen_symbol ()) in
    let init =
      let len = List.length l's in
      (* Names for the head of each list in the case that produces an element. *)
      let hd_vars = List.map l's ~f:(fun _ -> gen_symbol ()) in
      (* Names of the arguments of [loop] after [acc]: the remaining part of each list. *)
      let args_vars = List.map l's ~f:(fun _ -> gen_symbol ()) in
      (* Name given to whichever tail is being advanced in a given case. *)
      let tl_var = gen_symbol () in
      (* Last list empty (every other position is a wildcard): return the reversed
         accumulator. *)
      let base_case =
        let patts =
          List.rev ([%pat? []] :: List.init (len - 1) ~f:(fun _ -> [%pat? _]))
        in
        case ~guard:None ~lhs:(patt_tuple loc patts) ~rhs:[%expr Ppx_enumerate_lib.List.rev acc]
      in
      (* Every list non-empty: push [f] applied to the heads, then advance the first list
         only (its tail is the only one that is named). *)
      let apply_case =
        let patts = List.mapi hd_vars ~f:(fun i x ->
          [%pat? ([%p pvar ~loc x] :: [%p if i = 0 then
                      patt_lid tl_var
                    else
                      ppat_any ~loc])])
        in
        case ~guard:None ~lhs:(patt_tuple loc patts)
          ~rhs:(apply [%expr loop ([%e f (List.map hd_vars ~f:lid)] :: acc)]
                  (evar ~loc tl_var :: List.map (List.tl_exn args_vars) ~f:lid))
      in
      (* Case [i] (0 <= i <= len - 2): position [i] is empty and position [i + 1] is not.
         Positions [0 .. i] are reset to their full lists (the aliases), position [i + 1]
         moves to its tail, and the later positions are passed through unchanged. *)
      let decrement_cases =
        List.init (len - 1) ~f:(fun i ->
          let patts = List.init i ~f:(fun _ -> ppat_any ~loc)
                      @ [ [%pat? [] ]; [%pat?  (_ :: [%p pvar ~loc tl_var]) ] ]
                      @ List.init (len - i - 2) ~f:(fun _ -> ppat_any ~loc)
          in
          case ~guard:None ~lhs:(patt_tuple loc patts)
            ~rhs:(apply [%expr loop acc ]
                    (List.map ~f:lid (List.take alias_vars (i + 1))
                     @ evar ~loc tl_var ::
                       (List.map ~f:lid (List.drop args_vars (i + 2))))))
      in
      (* Without the exhaustiveness check, end the match with a catch-all case. *)
      let decrement_cases =
        if exhaust_check then
          decrement_cases
        else
          decrement_cases @ [
            case ~guard:None ~lhs:(ppat_any ~loc) ~rhs:[%expr assert false ]
          ]
      in
      let match_exp =
        pexp_match ~loc (tuple loc (List.map args_vars ~f:lid))
          (base_case :: apply_case :: decrement_cases)
      in
      (* Without the exhaustiveness check, also attach [ocaml.warning "-11"] to the match,
         which disables the unused match case warning for it. *)
      let match_exp =
        if exhaust_check then
          match_exp
        else
          let loc = Location.none in
          { match_exp with
            pexp_attributes = [
              attribute
                ~loc
                ~name:(Location.{ txt = "ocaml.warning"; loc })
                ~payload:(PStr [ pstr_eval ~loc (estring ~loc "-11") [] ])
            ]
          }
      in
      (* [loop] takes the accumulator and then one argument per list; it is started with an
         empty accumulator and every list at its full value. *)
      [%expr
        let rec loop acc =
          [%e eabstract ~loc (List.map args_vars ~f:patt_lid) match_exp]
        in
        [%e apply [%expr loop []] (List.map ~f:lid alias_vars)]
      ]
    in
    (* Wrap the loop in [let alias = input_list in ...] bindings, the first list outermost. *)
    Stdlib.ListLabels.fold_right2 alias_vars l's ~init ~f:(fun alias_var input_list acc ->
      [%expr
        let [%p pvar ~loc alias_var] = [%e input_list] in
        [%e acc]
      ])

(* [list_append loc l1 l2] builds an expression for the concatenation of the list
   expressions [l1] and [l2]. It does two things beyond emitting a plain call to
   [Ppx_enumerate_lib.List.append]:
   - it simplifies appends on static lists (a literal [[]] on either side, or a literal
     cons cell on the left), to make the generated code more readable;
   - it rewrites (List.append (List.append a b) c) as (List.append a (List.append b c)),
     to avoid a quadratic behaviour with long nesting to the left. *)
let rec list_append loc l1 l2 =
  (* The cases are tried in order: an empty right-hand side wins over everything else. *)
  match l1, l2 with
  | _, [%expr []] -> l1
  | [%expr []], _ -> l2
  | [%expr [%e? hd] :: [%e? tl]], _ ->
    let rest = list_append loc tl l2 in
    [%expr [%e hd] :: [%e rest]]
  | [%expr Ppx_enumerate_lib.List.append [%e? ll] [%e? lr]], _ ->
    let right = list_append loc lr l2 in
    list_append loc ll right
  | _, _ -> [%expr Ppx_enumerate_lib.List.append [%e l1] [%e l2]]

(* [enum ~exhaust_check ~main_type ty] returns an expression that evaluates to the list of
   all values of the core type [ty].

   - [bool], [unit] and [option] are enumerated directly.
   - Any other type constructor is turned, through [type_constr_conv] and
     [name_of_type_name], into a reference to that type's enumeration, applied to the
     enumerations of its arguments.
   - Tuples go through [product].
   - Closed polymorphic variants without a lower bound concatenate the enumeration of each
     row field (see [variant_case]).
   - A type variable becomes a reference to the variable named by [name_of_type_variable].
   - Anything else raises a located "unsupported type" error.

   [main_type] is only passed on to [variant_case]; nested types are enumerated with
   themselves as main type. All generated code is located at a ghost copy of [ty]'s
   location. *)
let rec enum ~exhaust_check ~main_type ty =
  let loc = { ty.ptyp_loc with loc_ghost = true } in
  (* Enumeration of a nested type, which acts as its own main type. *)
  let enum_nested tp = enum ~exhaust_check ~main_type:tp tp in
  match ty.ptyp_desc with
  | Ptyp_constr ({ txt = Lident "bool"; _ }, []) -> [%expr [ false; true ]]
  | Ptyp_constr ({ txt = Lident "unit"; _ }, []) -> [%expr [ () ]]
  | Ptyp_constr ({ txt = Lident "option"; _ }, [ tp ]) ->
    let somes = list_map loc (enum_nested tp) ~f:(fun e -> [%expr Some [%e e]]) in
    [%expr None :: [%e somes]]
  | Ptyp_constr (id, args) ->
    type_constr_conv ~loc id ~f:name_of_type_name (List.map args ~f:enum_nested)
  | Ptyp_tuple tps -> product ~exhaust_check loc tps (tuple loc)
  | Ptyp_variant (row_fields, Closed, None) ->
    List.fold_left row_fields ~init:[%expr []] ~f:(fun so_far row_field ->
      list_append loc so_far (variant_case ~exhaust_check loc row_field ~main_type))
  | Ptyp_var id -> evar ~loc (name_of_type_variable id)
  | _ -> Location.raise_errorf ~loc "ppx_enumerate: unsupported type"

(* Enumeration of a single row field of a polymorphic variant type.

   - An inherited type is enumerated with [enum] and the result is coerced to
     [<main_type> list], where the type variables of [main_type] are replaced by
     wildcards.
   - A tag that is not flagged as constant and has at least one argument type maps the
     tag over the enumeration of its first argument type.
   - Any other tag (flagged constant, or without argument types) yields a one-element
     list holding the bare tag. *)
and variant_case ~exhaust_check loc row_field ~main_type =
  match row_field.prf_desc with
  | Rinherit inherited ->
    let values = enum ~exhaust_check ~main_type inherited in
    [%expr ([%e values] :> [%t replace_variables_by_underscores main_type] list)]
  | Rtag ({ txt = tag; _ }, false, arg_type :: _) ->
    let arg_values = enum ~exhaust_check arg_type ~main_type:arg_type in
    list_map loc arg_values ~f:(fun arg -> pexp_variant ~loc tag (Some arg))
  | Rtag ({ txt = tag; _ }, _, _) -> [%expr [ [%e pexp_variant ~loc tag None] ]]

(* Enumeration of all values built with the constructor declared by [cd]:
   a one-element list for a constructor without arguments, otherwise the constructor
   applied to every combination of its arguments (positional or inline record). *)
and constructor_case ~exhaust_check loc cd =
  let construct arg = econstruct cd arg in
  match cd.pcd_args with
  | Pcstr_tuple [] -> [%expr [ [%e construct None] ]]
  | Pcstr_tuple arg_types ->
    product ~exhaust_check loc arg_types (fun args ->
      construct (Some (pexp_tuple ~loc args)))
  | Pcstr_record fields ->
    enum_of_lab_decs ~exhaust_check ~loc fields ~k:(fun record ->
      construct (Some record))

(* Enumeration of all records with the label declarations [lds].

   The field types are combined with [product]; for each combination a record expression
   is built (fields in declaration order, no [with] base) and handed to the continuation
   [k], whose result is what ends up in the enumeration. *)
and enum_of_lab_decs ~exhaust_check ~loc lds ~k =
  let field_names = List.map lds ~f:(fun ld -> ld.pld_name) in
  let field_types = List.map lds ~f:(fun ld -> ld.pld_type) in
  let build_record values =
    let fields =
      List.map2_exn field_names values ~f:(fun field_name value ->
        Located.map lident field_name, value)
    in
    k (pexp_record ~loc fields None)
  in
  product ~exhaust_check loc field_types build_record

(* Enumerates each type of [tps] (each one being its own main type) and combines the
   resulting lists with [cartesian_product_map], using [f] to build each element. *)
and product ~exhaust_check loc tps f =
  let enumerate_component tp = enum ~exhaust_check ~main_type:tp tp in
  let component_lists = List.map tps ~f:enumerate_component in
  cartesian_product_map ~exhaust_check component_lists ~f loc

(* Wraps [typ] in an explicit universal quantification over the names of the type
   parameters [tps]. With no parameters, [typ] is returned unchanged. *)
let quantify loc tps typ =
  if List.is_empty tps
  then typ
  else ptyp_poly ~loc (List.map tps ~f:get_type_param_name) typ

(* Structure items defining the enumeration value for the type declaration [td].

   The generated binding is named by [name_of_type_name]. It takes one argument per type
   parameter (named by [name_of_type_variable]) and is annotated with the type from
   [enumeration_type_of_td], quantified over the type parameters.

   - Variant types concatenate the enumeration of each constructor.
   - Record types enumerate every combination of field values.
   - Open types raise a located error.
   - Abstract types enumerate their manifest, or generate the empty list when they have
     none. *)
let enum_of_td ~exhaust_check td =
  let td = name_type_params_in_td td in
  let loc = td.ptype_loc in
  let all =
    match td.ptype_kind, td.ptype_manifest with
    | Ptype_variant cds, _ ->
      (* Process [cd] elements in same order as camlp4 to avoid code-gen diffs caused by
         different order of [gen_symbol] calls *)
      List.fold_left cds ~init:[%expr []] ~f:(fun so_far cd ->
        list_append loc so_far (constructor_case ~exhaust_check loc cd))
    | Ptype_record lds, _ ->
      enum_of_lab_decs ~exhaust_check ~loc lds ~k:(fun record -> record)
    | Ptype_open, _ ->
      Location.raise_errorf ~loc "ppx_enumerate: open types not supported"
    | Ptype_abstract, None -> [%expr []]
    | Ptype_abstract, Some manifest ->
      (* The declared type itself, with a wildcard for each type parameter. *)
      let main_type =
        ptyp_constr
          ~loc
          (Located.map lident td.ptype_name)
          (List.map td.ptype_params ~f:(fun _ -> ptyp_any ~loc))
      in
      enum ~exhaust_check manifest ~main_type
  in
  let name = name_of_type_name td.ptype_name.txt in
  let args =
    List.map td.ptype_params ~f:(fun ((tp, _) as param) ->
      let arg_name = name_of_type_variable (get_type_param_name param).txt in
      pvar ~loc:tp.ptyp_loc arg_name)
  in
  let enumeration_type = quantify loc td.ptype_params (enumeration_type_of_td td) in
  let body = eabstract ~loc args all in
  let binding = pvar ~loc name in
  match args with
  | [] ->
    (* No type parameters: constrain body rather than pattern. *)
    [%str let [%p binding] = ([%e body] : [%t enumeration_type])]
  | _ :: _ -> [%str let [%p binding] : [%t enumeration_type] = [%e body]]


(* The [enumerate] deriver.

   In structures it accepts the [no_exhaustiveness_check] flag and generates the
   enumeration with [enum_of_td]; the flag, when present, turns [exhaust_check] off.
   Only a single type declaration is handled at a time: any other number of
   declarations raises a located error. In signatures it generates items with
   [sig_of_tds]. *)
let enumerate =
  let str_args = Deriving.Args.(empty +> flag "no_exhaustiveness_check") in
  let str_type_decl =
    Deriving.Generator.make str_args
      (fun ~loc ~path:_ (_rec, tds) no_exhaustiveness_check ->
         match tds with
         | [ td ] -> enum_of_td ~exhaust_check:(not no_exhaustiveness_check) td
         | _ ->
           (* The message previously read "is support by", which is ungrammatical. *)
           Location.raise_errorf ~loc
             "only one type at a time is supported by ppx_enumerate")
  in
  let sig_type_decl = Deriving.Generator.make Deriving.Args.empty sig_of_tds in
  Deriving.add "enumerate" ~str_type_decl ~sig_type_decl

(* Registers a deriver named "all" that only provides an extension: given a core type, it
   expands to the enumeration of that type (with the exhaustiveness check on and the type
   acting as its own main type). The resulting deriver value is not needed afterwards. *)
let () =
  let all_extension ~loc:_ ~path:_ ty = enum ~exhaust_check:true ~main_type:ty ty in
  Deriving.ignore (Deriving.add "all" ~extension:all_extension)
OCaml

Innovation. Community. Security.