mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Closure conversion fix
Support for manipulating toplevel functions as values was buggy, because the recursion after eta-expansion would fall into the pattern for a `let..in` and not do the expected transformation. The patch explicitely builds the closure in that case, avoiding such issues with recursion.
This commit is contained in:
parent
23b196aace
commit
21cea5c968
@ -52,6 +52,84 @@ let join_vars : ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t =
|
||||
|
||||
(** {1 Transforming closures}*)
|
||||
|
||||
let build_closure :
|
||||
type m.
|
||||
m ctx ->
|
||||
(m expr Var.t * m mark) list ->
|
||||
m expr boxed ->
|
||||
m expr Var.t array ->
|
||||
typ list ->
|
||||
m mark ->
|
||||
m expr boxed =
|
||||
fun ctx free_vars body args tys m ->
|
||||
(* λ x.t *)
|
||||
let pos = Expr.mark_pos m in
|
||||
let mark_ty ty = Expr.with_ty m ty in
|
||||
let free_vars_types = List.map (fun (_, m) -> Expr.maybe_ty m) free_vars in
|
||||
(* x1, ..., xn *)
|
||||
let code_var = Var.make ctx.name_context in
|
||||
(* code *)
|
||||
let closure_env_arg_var = Var.make "env" in
|
||||
let closure_env_var = Var.make "env" in
|
||||
let env_ty = TTuple free_vars_types, pos in
|
||||
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
|
||||
let new_closure_body =
|
||||
Expr.make_let_in closure_env_var env_ty
|
||||
(Expr.eappop
|
||||
~op:(Operator.FromClosureEnv, pos)
|
||||
~tys:[TClosureEnv, pos]
|
||||
~args:[Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, pos))]
|
||||
(mark_ty env_ty))
|
||||
(Expr.make_multiple_let_in
|
||||
(Array.of_list (List.map fst free_vars))
|
||||
free_vars_types
|
||||
(List.mapi
|
||||
(fun i _ ->
|
||||
Expr.make_tupleaccess
|
||||
(Expr.evar closure_env_var (mark_ty env_ty))
|
||||
i (List.length free_vars) pos)
|
||||
free_vars)
|
||||
body pos)
|
||||
pos
|
||||
in
|
||||
(* fun env arg0 ... -> new_closure_body *)
|
||||
let new_closure =
|
||||
Expr.make_abs
|
||||
(Array.append [| closure_env_arg_var |] args)
|
||||
new_closure_body
|
||||
((TClosureEnv, pos) :: tys)
|
||||
pos
|
||||
in
|
||||
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
|
||||
Expr.make_let_in code_var new_closure_ty new_closure
|
||||
(Expr.make_tuple
|
||||
((Bindlib.box_var code_var, mark_ty new_closure_ty)
|
||||
:: [
|
||||
Expr.eappop
|
||||
~op:(Operator.ToClosureEnv, pos)
|
||||
~tys:
|
||||
[
|
||||
( (if free_vars = [] then TLit TUnit
|
||||
else TTuple free_vars_types),
|
||||
pos );
|
||||
]
|
||||
~args:
|
||||
[
|
||||
(if free_vars = [] then
|
||||
Expr.elit LUnit (mark_ty (TLit TUnit, pos))
|
||||
else
|
||||
Expr.etuple
|
||||
(List.map
|
||||
(fun (extra_var, m) ->
|
||||
Bindlib.box_var extra_var, Expr.with_pos pos m)
|
||||
free_vars)
|
||||
(mark_ty (TTuple free_vars_types, pos)));
|
||||
]
|
||||
(mark_ty (TClosureEnv, pos));
|
||||
])
|
||||
m)
|
||||
pos
|
||||
|
||||
(** Returns the expression with closed closures and the set of free variables
|
||||
inside this new expression. Implementation guided by
|
||||
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10
|
||||
@ -71,7 +149,7 @@ let rec transform_closures_expr :
|
||||
| EVar v -> (
|
||||
match Var.Map.find_opt v ctx.globally_bound_vars with
|
||||
| None -> Var.Map.singleton v m, (Bindlib.box_var v, m)
|
||||
| Some (TArrow (targs, tret), _) ->
|
||||
| Some ((TArrow (targs, tret), _) as fty) ->
|
||||
(* Here we eta-expand the argument to make sure function pointers are
|
||||
correctly casted as closures *)
|
||||
let args = Array.init (List.length targs) (fun _ -> Var.make "eta_arg") in
|
||||
@ -80,26 +158,15 @@ let rec transform_closures_expr :
|
||||
(fun v ty -> Expr.evar v (Expr.with_ty m ty))
|
||||
(Array.to_list args) targs
|
||||
in
|
||||
let e =
|
||||
Expr.eabs
|
||||
(Expr.bind args
|
||||
(Expr.eapp ~f:(Expr.rebox e) ~args:arg_vars ~tys:targs
|
||||
(Expr.with_ty m tret)))
|
||||
targs m
|
||||
in
|
||||
let boxed =
|
||||
let ctx =
|
||||
(* We hide the type of the toplevel definition so that the function
|
||||
doesn't loop *)
|
||||
{
|
||||
ctx with
|
||||
globally_bound_vars =
|
||||
Var.Map.add v (Expr.maybe_ty m) ctx.globally_bound_vars;
|
||||
}
|
||||
let closure =
|
||||
let body =
|
||||
Expr.eapp
|
||||
~f:(Bindlib.box_var v, Expr.with_ty m fty)
|
||||
~args:arg_vars ~tys:targs (Expr.with_ty m tret)
|
||||
in
|
||||
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
|
||||
build_closure ctx [] body args targs m
|
||||
in
|
||||
Bindlib.unbox boxed
|
||||
Var.Map.empty, closure
|
||||
| Some _ -> Var.Map.empty, (Bindlib.box_var v, m))
|
||||
| EMatch { e; cases; name } ->
|
||||
let free_vars, new_e = (transform_closures_expr ctx) e in
|
||||
@ -147,89 +214,15 @@ let rec transform_closures_expr :
|
||||
~f:(Expr.eabs new_binder (List.map translate_type tys) e1_pos)
|
||||
~args:new_args ~tys m )
|
||||
| EAbs { binder; tys } ->
|
||||
(* λ x.t *)
|
||||
let binder_pos = Expr.mark_pos m in
|
||||
let mark_ty ty = Expr.with_ty m ty in
|
||||
(* Converting the closure. *)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
(* t *)
|
||||
let body_vars, new_body = (transform_closures_expr ctx) body in
|
||||
let free_vars, body = (transform_closures_expr ctx) body in
|
||||
(* [[t]] *)
|
||||
let extra_vars =
|
||||
Array.fold_left (fun m v -> Var.Map.remove v m) body_vars vars
|
||||
let free_vars =
|
||||
Array.fold_left (fun m v -> Var.Map.remove v m) free_vars vars
|
||||
in
|
||||
let extra_vars_list = Var.Map.bindings extra_vars in
|
||||
let extra_vars_types =
|
||||
List.map (fun (_, m) -> Expr.maybe_ty m) extra_vars_list
|
||||
in
|
||||
(* x1, ..., xn *)
|
||||
let code_var = Var.make ctx.name_context in
|
||||
(* code *)
|
||||
let closure_env_arg_var = Var.make "env" in
|
||||
let closure_env_var = Var.make "env" in
|
||||
let env_ty = TTuple extra_vars_types, binder_pos in
|
||||
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
|
||||
let new_closure_body =
|
||||
Expr.make_let_in closure_env_var env_ty
|
||||
(Expr.eappop
|
||||
~op:(Operator.FromClosureEnv, binder_pos)
|
||||
~tys:[TClosureEnv, binder_pos]
|
||||
~args:
|
||||
[Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, binder_pos))]
|
||||
(mark_ty env_ty))
|
||||
(Expr.make_multiple_let_in
|
||||
(Array.of_list (List.map fst extra_vars_list))
|
||||
extra_vars_types
|
||||
(List.mapi
|
||||
(fun i _ ->
|
||||
Expr.make_tupleaccess
|
||||
(Expr.evar closure_env_var (mark_ty env_ty))
|
||||
i
|
||||
(List.length extra_vars_list)
|
||||
binder_pos)
|
||||
extra_vars_list)
|
||||
new_body binder_pos)
|
||||
binder_pos
|
||||
in
|
||||
(* fun env arg0 ... -> new_closure_body *)
|
||||
let new_closure =
|
||||
Expr.make_abs
|
||||
(Array.append [| closure_env_arg_var |] vars)
|
||||
new_closure_body
|
||||
((TClosureEnv, binder_pos) :: tys)
|
||||
(Expr.pos e)
|
||||
in
|
||||
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
|
||||
( extra_vars,
|
||||
Expr.make_let_in code_var new_closure_ty new_closure
|
||||
(Expr.make_tuple
|
||||
((Bindlib.box_var code_var, mark_ty new_closure_ty)
|
||||
:: [
|
||||
Expr.eappop
|
||||
~op:(Operator.ToClosureEnv, binder_pos)
|
||||
~tys:
|
||||
[
|
||||
( (if extra_vars_list = [] then TLit TUnit
|
||||
else TTuple extra_vars_types),
|
||||
binder_pos );
|
||||
]
|
||||
~args:
|
||||
[
|
||||
(if extra_vars_list = [] then
|
||||
Expr.elit LUnit (mark_ty (TLit TUnit, binder_pos))
|
||||
else
|
||||
Expr.etuple
|
||||
(List.map
|
||||
(fun (extra_var, m) ->
|
||||
( Bindlib.box_var extra_var,
|
||||
Expr.with_pos binder_pos m ))
|
||||
extra_vars_list)
|
||||
(mark_ty (TTuple extra_vars_types, binder_pos)));
|
||||
]
|
||||
(mark_ty (TClosureEnv, binder_pos));
|
||||
])
|
||||
m)
|
||||
(Expr.pos e) )
|
||||
free_vars, build_closure ctx (Var.Map.bindings free_vars) body vars tys m
|
||||
| EAppOp
|
||||
{
|
||||
op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op;
|
||||
@ -318,10 +311,7 @@ let rec transform_closures_expr :
|
||||
pos)
|
||||
pos
|
||||
in
|
||||
( free_vars,
|
||||
Expr.make_let_in code_env_var
|
||||
(TAny, Expr.pos e)
|
||||
new_e1 call_expr (Expr.pos e) )
|
||||
free_vars, Expr.make_let_in code_env_var (TAny, pos) new_e1 call_expr pos
|
||||
| _ -> .
|
||||
|
||||
let transform_closures_scope_let ctx scope_body_expr =
|
||||
@ -674,6 +664,4 @@ let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
let closure_conversion (p : 'm program) : 'm program =
|
||||
let new_p = transform_closures_program p in
|
||||
let new_p = hoist_closures_program (Bindlib.unbox new_p) in
|
||||
(* FIXME: either fix the types of the marks, or remove the types annotations
|
||||
during the main processing (rather than requiring a new traversal) *)
|
||||
Bindlib.unbox new_p
|
||||
|
@ -100,7 +100,7 @@ val program :
|
||||
('a, 'm) gexpr program ->
|
||||
('a, typed) gexpr program
|
||||
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the
|
||||
later dcalc/lcalc stages.
|
||||
later dcalc/lcalc stages).
|
||||
|
||||
Any existing type annotations are checked for unification. Use
|
||||
[Program.untype] to remove them beforehand if this is not the desired
|
||||
|
@ -41,3 +41,69 @@ let scope S (S_in: S_in {x_in: bool}): S {z: integer} =
|
||||
return { S z = z; }
|
||||
|
||||
```
|
||||
|
||||
|
||||
```catala
|
||||
declaration scope S2:
|
||||
output dummy content boolean
|
||||
input output cfun2 content decimal depends on x content integer
|
||||
|
||||
scope S2:
|
||||
definition dummy equals false
|
||||
|
||||
declaration scope S2Use:
|
||||
internal fun content decimal depends on y content integer
|
||||
output o content (S2, S2)
|
||||
|
||||
declaration fun2 content decimal depends on y content integer equals y / 3
|
||||
|
||||
scope S2Use:
|
||||
definition fun of y equals y / 2
|
||||
definition o equals
|
||||
(output of S2 with { -- cfun2: fun },
|
||||
output of S2 with { -- cfun2: fun2 })
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala Lcalc --avoid-exceptions -O --closure-conversion -s S2Use
|
||||
let scope S2Use
|
||||
(S2Use_in: S2Use_in)
|
||||
: S2Use {
|
||||
o:
|
||||
(S2 {
|
||||
dummy: bool;
|
||||
cfun2: ((closure_env, integer) → decimal, closure_env)
|
||||
},
|
||||
S2 {
|
||||
dummy: bool;
|
||||
cfun2: ((closure_env, integer) → decimal, closure_env)
|
||||
})
|
||||
}
|
||||
=
|
||||
let set fun : ((closure_env, integer) → decimal, closure_env) =
|
||||
(closure_fun, to_closure_env ())
|
||||
in
|
||||
let set o :
|
||||
(S2 {
|
||||
dummy: bool;
|
||||
cfun2: ((closure_env, integer) → decimal, closure_env)
|
||||
},
|
||||
S2 {
|
||||
dummy: bool;
|
||||
cfun2: ((closure_env, integer) → decimal, closure_env)
|
||||
}) =
|
||||
(let result : S2 = S2 { S2_in cfun2_in = fun; } in
|
||||
{ S2
|
||||
dummy = result.dummy;
|
||||
cfun2 = (closure_o, to_closure_env (result));
|
||||
},
|
||||
let result : S2 =
|
||||
S2 { S2_in cfun2_in = (closure_o, to_closure_env ()); }
|
||||
in
|
||||
{ S2
|
||||
dummy = result.dummy;
|
||||
cfun2 = (closure_o, to_closure_env (result));
|
||||
})
|
||||
in
|
||||
return { S2Use o = o; }
|
||||
```
|
||||
|
Loading…
Reference in New Issue
Block a user