package catala

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file closure_conversion.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
(* This file is part of the Catala compiler, a specification language for tax
   and social benefits computation rules. Copyright (C) 2022 Inria, contributor:
   Denis Merigoux <denis.merigoux@inria.fr>

   Licensed under the Apache License, Version 2.0 (the "License"); you may not
   use this file except in compliance with the License. You may obtain a copy of
   the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
   WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
   License for the specific language governing permissions and limitations under
   the License. *)

open Catala_utils
open Shared_ast
open Ast
module D = Dcalc.Ast

(** TODO: This version is not yet debugged and ought to be specialized when
    Lcalc has more structure. *)

type 'm ctx = {
  name_context : string;
      (** Name given to the variable holding the code part of each closure
          built while converting the current scope or toplevel definition. *)
  globally_bound_vars : 'm expr Var.Set.t;
      (** Variables bound at top level: they are never reported as free, hence
          never captured in a closure environment. *)
}

(** [closure_conversion_expr ctx e] returns [e] where every function value has
    been turned into an explicit closure.

    A lambda-abstraction whose body has free variables [y1, ..., yn] (not
    counting [ctx.globally_bound_vars]) becomes the tuple [(code, y1, ..., yn)].
    Here [code] is a function taking the whole tuple as an extra first
    parameter [env] and re-binding each [yi] to component [i] of [env]
    (component [0] being [code] itself). Correspondingly, an application
    [f a1 ... ak] becomes [let env = f in let code = env.0 in code env a1 ... ak].

    Left unconverted: let-bindings (applications of an [EAbs]), the arms of
    [EMatch] expressions, operator applications, and applications of a variable
    of [ctx.globally_bound_vars] (scope calls).

    Implementation guided by
    http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=9.

    @raise Failure if an arm of a match expression is not an [EAbs]. *)
let closure_conversion_expr (type m) (ctx : m ctx) (e : m expr) : m expr boxed =
  (* [without_bound vars free_vars] removes from [free_vars] the variables
     introduced by a binder: they are not free outside of that binder. *)
  let without_bound vars free_vars =
    Var.Set.diff free_vars (Var.Set.of_list (Array.to_list vars))
  in
  (* [aux e] returns the free variables of [e] that are not globally bound,
     together with the converted expression. *)
  let rec aux e =
    let m = Marked.get_mark e in
    match Marked.unmark e with
    | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
    | EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _
    | ECatch _ ->
      Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
    | EVar v ->
      ( (if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
        else Var.Set.singleton v),
        (Bindlib.box_var v, m) )
    | EMatch { e = matched; cases; name } ->
      let free_vars, new_matched = aux matched in
      (* We do not close the closures inside the arms of the match expression,
         since they get a special treatment at compilation to Scalc. *)
      let free_vars, new_cases =
        EnumConstructor.Map.fold
          (fun cons arm (free_vars, new_cases) ->
            match Marked.unmark arm with
            | EAbs { binder; tys } ->
              let vars, body = Bindlib.unmbind binder in
              let arm_free_vars, new_body = aux body in
              (* The arm's parameters are bound by the arm itself, they are
                 not free in the whole match expression. *)
              let arm_free_vars = without_bound vars arm_free_vars in
              let new_binder = Expr.bind vars new_body in
              ( Var.Set.union free_vars arm_free_vars,
                EnumConstructor.Map.add cons
                  (Expr.eabs new_binder tys (Marked.get_mark arm))
                  new_cases )
            | _ -> failwith "should not happen")
          cases
          (free_vars, EnumConstructor.Map.empty)
      in
      free_vars, Expr.ematch new_matched name new_cases m
    | EApp { f = EAbs { binder; tys }, e1_pos; args } ->
      (* let-binding, we should not close these *)
      let vars, body = Bindlib.unmbind binder in
      let body_free_vars, new_body = aux body in
      (* The let-bound variables are not free outside of the let-binding.
         Reporting them would make an enclosing closure "capture" them, i.e.
         reference them where its environment tuple is built, out of their
         scope. *)
      let free_vars = without_bound vars body_free_vars in
      let new_binder = Expr.bind vars new_body in
      let free_vars, new_args = aux_args free_vars args in
      free_vars, Expr.eapp (Expr.eabs new_binder tys e1_pos) new_args m
    | EAbs { binder; tys } ->
      (* λ x.t *)
      let binder_mark = m in
      let binder_pos = Expr.mark_pos binder_mark in
      (* Converting the closure. *)
      let vars, body = Bindlib.unmbind binder in
      (* t *)
      let body_vars, new_body = aux body in
      (* [[t]] *)
      let extra_vars = without_bound vars body_vars in
      let extra_vars_list = Var.Set.elements extra_vars in
      (* x1, ..., xn *)
      let n_extra = List.length extra_vars_list in
      let code_var = Var.make ctx.name_context in
      (* code *)
      let inner_c_var = Var.make "env" in
      let any_ty = TAny, binder_pos in
      (* Inside the code, [env] is the whole closure tuple [(code, x1, ..., xn)]
         built below: it has [n + 1] components and [xi] is component [i]. *)
      let new_closure_body =
        Expr.make_multiple_let_in
          (Array.of_list extra_vars_list)
          (List.map (fun _ -> any_ty) extra_vars_list)
          (List.mapi
             (fun i _ ->
               Expr.etupleaccess
                 (Expr.evar inner_c_var binder_mark)
                 (i + 1) (n_extra + 1) binder_mark)
             extra_vars_list)
          new_body binder_pos
      in
      let new_closure =
        Expr.make_abs
          (Array.append [| inner_c_var |] vars)
          new_closure_body
          ((TAny, binder_pos) :: tys)
          (Expr.pos e)
      in
      ( extra_vars,
        Expr.make_let_in code_var
          (TAny, Expr.pos e)
          new_closure
          (Expr.etuple
             ((Bindlib.box_var code_var, binder_mark)
             :: List.map
                  (fun extra_var -> Bindlib.box_var extra_var, binder_mark)
                  extra_vars_list)
             m)
          (Expr.pos e) )
    | EApp { f = EOp _, _; _ } ->
      (* This corresponds to an operator call, which we don't want to
         transform *)
      Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
    | EApp { f = EVar v, _; _ } when Var.Set.mem v ctx.globally_bound_vars ->
      (* This corresponds to a scope call, which we don't want to transform *)
      Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:aux e
    | EApp { f = e1; args } ->
      let free_vars, new_e1 = aux e1 in
      let env_var = Var.make "env" in
      let code_var = Var.make "code" in
      let free_vars, new_args = aux_args free_vars args in
      (* NOTE(review): the closure tuple has [1 + number of captured variables]
         components, which is not known at the call site; the size given to
         [etupleaccess] below is a placeholder kept from the original code.
         Confirm that later passes do not rely on it. *)
      let call_expr =
        let m1 = Marked.get_mark e1 in
        Expr.make_let_in code_var
          (TAny, Expr.pos e)
          (Expr.etupleaccess
             (Bindlib.box_var env_var, m1)
             0
             (List.length new_args + 1)
             m)
          (Expr.eapp
             (Bindlib.box_var code_var, m1)
             ((Bindlib.box_var env_var, m1) :: new_args)
             m)
          (Expr.pos e)
      in
      ( free_vars,
        Expr.make_let_in env_var
          (TAny, Expr.pos e)
          new_e1 call_expr (Expr.pos e) )
  (* [aux_args free_vars args] converts every argument of [args], adding their
     free variables to [free_vars]. *)
  and aux_args free_vars args =
    List.fold_right
      (fun arg (free_vars, new_args) ->
        let arg_free_vars, new_arg = aux arg in
        Var.Set.union free_vars arg_free_vars, new_arg :: new_args)
      args (free_vars, [])
  in
  let _free_vars, e' = aux e in
  e'

(** [closure_conversion p] applies [closure_conversion_expr] to every code item
    of [p]. It converts the expressions of the let-bindings of each scope body
    and the body of each toplevel definition. While an item is converted, the
    variables of the items preceding it are treated as globally bound, as are
    [handle_default] and [handle_default_opt], so they are never captured. *)
let closure_conversion (p : 'm program) : 'm program Bindlib.box =
  (* Runtime helpers, always in scope. *)
  let runtime_vars =
    [handle_default; handle_default_opt]
    |> List.map Var.translate
    |> Var.Set.of_list
  in
  (* Converts one code item, given the variables bound by the items before it;
     returns the extended set of bound variables with the converted item. *)
  let convert_item globals var item =
    let ctx_of info =
      { name_context = Marked.unmark info; globally_bound_vars = globals }
    in
    let new_item =
      match item with
      | ScopeDef (name, body) ->
        let ctx = ctx_of (ScopeName.get_info name) in
        let input_var, lets = Bindlib.unbind body.scope_body_expr in
        let new_lets =
          Scope.map_exprs_in_lets
            ~f:(closure_conversion_expr ctx)
            ~varf:(fun v -> v)
            lets
        in
        Bindlib.bind_var input_var new_lets
        |> Bindlib.box_apply (fun scope_body_expr ->
               ScopeDef (name, { body with scope_body_expr }))
      | Topdef (name, ty, expr) ->
        let ctx = ctx_of (TopdefName.get_info name) in
        closure_conversion_expr ctx expr
        |> Expr.Box.lift
        |> Bindlib.box_apply (fun e -> Topdef (name, ty, e))
    in
    Var.Set.add var globals, new_item
  in
  let _, new_code_items =
    Scope.fold_map ~f:convert_item ~varf:(fun v -> v) runtime_vars p.code_items
  in
  Bindlib.box_apply (fun code_items -> { p with code_items }) new_code_items
OCaml

Innovation. Community. Security.