Closure conversion fix

Support for manipulating toplevel functions as values was buggy, because the
recursion after eta-expansion would fall into the pattern for a `let..in` and
not do the expected transformation.

The patch explicitely builds the closure in that case, avoiding such issues with
recursion.
This commit is contained in:
Louis Gesbert 2024-06-20 15:08:16 +02:00
parent 23b196aace
commit 21cea5c968
3 changed files with 158 additions and 104 deletions

View File

@ -52,6 +52,84 @@ let join_vars : ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t =
(** {1 Transforming closures}*)
let build_closure :
type m.
m ctx ->
(m expr Var.t * m mark) list ->
m expr boxed ->
m expr Var.t array ->
typ list ->
m mark ->
m expr boxed =
fun ctx free_vars body args tys m ->
(* λ x.t *)
let pos = Expr.mark_pos m in
let mark_ty ty = Expr.with_ty m ty in
let free_vars_types = List.map (fun (_, m) -> Expr.maybe_ty m) free_vars in
(* x1, ..., xn *)
let code_var = Var.make ctx.name_context in
(* code *)
let closure_env_arg_var = Var.make "env" in
let closure_env_var = Var.make "env" in
let env_ty = TTuple free_vars_types, pos in
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
let new_closure_body =
Expr.make_let_in closure_env_var env_ty
(Expr.eappop
~op:(Operator.FromClosureEnv, pos)
~tys:[TClosureEnv, pos]
~args:[Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, pos))]
(mark_ty env_ty))
(Expr.make_multiple_let_in
(Array.of_list (List.map fst free_vars))
free_vars_types
(List.mapi
(fun i _ ->
Expr.make_tupleaccess
(Expr.evar closure_env_var (mark_ty env_ty))
i (List.length free_vars) pos)
free_vars)
body pos)
pos
in
(* fun env arg0 ... -> new_closure_body *)
let new_closure =
Expr.make_abs
(Array.append [| closure_env_arg_var |] args)
new_closure_body
((TClosureEnv, pos) :: tys)
pos
in
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
Expr.make_let_in code_var new_closure_ty new_closure
(Expr.make_tuple
((Bindlib.box_var code_var, mark_ty new_closure_ty)
:: [
Expr.eappop
~op:(Operator.ToClosureEnv, pos)
~tys:
[
( (if free_vars = [] then TLit TUnit
else TTuple free_vars_types),
pos );
]
~args:
[
(if free_vars = [] then
Expr.elit LUnit (mark_ty (TLit TUnit, pos))
else
Expr.etuple
(List.map
(fun (extra_var, m) ->
Bindlib.box_var extra_var, Expr.with_pos pos m)
free_vars)
(mark_ty (TTuple free_vars_types, pos)));
]
(mark_ty (TClosureEnv, pos));
])
m)
pos
(** Returns the expression with closed closures and the set of free variables
inside this new expression. Implementation guided by
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10
@ -71,7 +149,7 @@ let rec transform_closures_expr :
| EVar v -> (
match Var.Map.find_opt v ctx.globally_bound_vars with
| None -> Var.Map.singleton v m, (Bindlib.box_var v, m)
| Some (TArrow (targs, tret), _) ->
| Some ((TArrow (targs, tret), _) as fty) ->
(* Here we eta-expand the argument to make sure function pointers are
correctly casted as closures *)
let args = Array.init (List.length targs) (fun _ -> Var.make "eta_arg") in
@ -80,26 +158,15 @@ let rec transform_closures_expr :
(fun v ty -> Expr.evar v (Expr.with_ty m ty))
(Array.to_list args) targs
in
let e =
Expr.eabs
(Expr.bind args
(Expr.eapp ~f:(Expr.rebox e) ~args:arg_vars ~tys:targs
(Expr.with_ty m tret)))
targs m
in
let boxed =
let ctx =
(* We hide the type of the toplevel definition so that the function
doesn't loop *)
{
ctx with
globally_bound_vars =
Var.Map.add v (Expr.maybe_ty m) ctx.globally_bound_vars;
}
let closure =
let body =
Expr.eapp
~f:(Bindlib.box_var v, Expr.with_ty m fty)
~args:arg_vars ~tys:targs (Expr.with_ty m tret)
in
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
build_closure ctx [] body args targs m
in
Bindlib.unbox boxed
Var.Map.empty, closure
| Some _ -> Var.Map.empty, (Bindlib.box_var v, m))
| EMatch { e; cases; name } ->
let free_vars, new_e = (transform_closures_expr ctx) e in
@ -147,89 +214,15 @@ let rec transform_closures_expr :
~f:(Expr.eabs new_binder (List.map translate_type tys) e1_pos)
~args:new_args ~tys m )
| EAbs { binder; tys } ->
(* λ x.t *)
let binder_pos = Expr.mark_pos m in
let mark_ty ty = Expr.with_ty m ty in
(* Converting the closure. *)
let vars, body = Bindlib.unmbind binder in
(* t *)
let body_vars, new_body = (transform_closures_expr ctx) body in
let free_vars, body = (transform_closures_expr ctx) body in
(* [[t]] *)
let extra_vars =
Array.fold_left (fun m v -> Var.Map.remove v m) body_vars vars
let free_vars =
Array.fold_left (fun m v -> Var.Map.remove v m) free_vars vars
in
let extra_vars_list = Var.Map.bindings extra_vars in
let extra_vars_types =
List.map (fun (_, m) -> Expr.maybe_ty m) extra_vars_list
in
(* x1, ..., xn *)
let code_var = Var.make ctx.name_context in
(* code *)
let closure_env_arg_var = Var.make "env" in
let closure_env_var = Var.make "env" in
let env_ty = TTuple extra_vars_types, binder_pos in
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
let new_closure_body =
Expr.make_let_in closure_env_var env_ty
(Expr.eappop
~op:(Operator.FromClosureEnv, binder_pos)
~tys:[TClosureEnv, binder_pos]
~args:
[Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, binder_pos))]
(mark_ty env_ty))
(Expr.make_multiple_let_in
(Array.of_list (List.map fst extra_vars_list))
extra_vars_types
(List.mapi
(fun i _ ->
Expr.make_tupleaccess
(Expr.evar closure_env_var (mark_ty env_ty))
i
(List.length extra_vars_list)
binder_pos)
extra_vars_list)
new_body binder_pos)
binder_pos
in
(* fun env arg0 ... -> new_closure_body *)
let new_closure =
Expr.make_abs
(Array.append [| closure_env_arg_var |] vars)
new_closure_body
((TClosureEnv, binder_pos) :: tys)
(Expr.pos e)
in
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
( extra_vars,
Expr.make_let_in code_var new_closure_ty new_closure
(Expr.make_tuple
((Bindlib.box_var code_var, mark_ty new_closure_ty)
:: [
Expr.eappop
~op:(Operator.ToClosureEnv, binder_pos)
~tys:
[
( (if extra_vars_list = [] then TLit TUnit
else TTuple extra_vars_types),
binder_pos );
]
~args:
[
(if extra_vars_list = [] then
Expr.elit LUnit (mark_ty (TLit TUnit, binder_pos))
else
Expr.etuple
(List.map
(fun (extra_var, m) ->
( Bindlib.box_var extra_var,
Expr.with_pos binder_pos m ))
extra_vars_list)
(mark_ty (TTuple extra_vars_types, binder_pos)));
]
(mark_ty (TClosureEnv, binder_pos));
])
m)
(Expr.pos e) )
free_vars, build_closure ctx (Var.Map.bindings free_vars) body vars tys m
| EAppOp
{
op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op;
@ -318,10 +311,7 @@ let rec transform_closures_expr :
pos)
pos
in
( free_vars,
Expr.make_let_in code_env_var
(TAny, Expr.pos e)
new_e1 call_expr (Expr.pos e) )
free_vars, Expr.make_let_in code_env_var (TAny, pos) new_e1 call_expr pos
| _ -> .
let transform_closures_scope_let ctx scope_body_expr =
@ -674,6 +664,4 @@ let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
let closure_conversion (p : 'm program) : 'm program =
let new_p = transform_closures_program p in
let new_p = hoist_closures_program (Bindlib.unbox new_p) in
(* FIXME: either fix the types of the marks, or remove the types annotations
during the main processing (rather than requiring a new traversal) *)
Bindlib.unbox new_p

View File

@ -100,7 +100,7 @@ val program :
('a, 'm) gexpr program ->
('a, typed) gexpr program
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the
later dcalc/lcalc stages.
later dcalc/lcalc stages).
Any existing type annotations are checked for unification. Use
[Program.untype] to remove them beforehand if this is not the desired

View File

@ -41,3 +41,69 @@ let scope S (S_in: S_in {x_in: bool}): S {z: integer} =
return { S z = z; }
```
```catala
declaration scope S2:
output dummy content boolean
input output cfun2 content decimal depends on x content integer
scope S2:
definition dummy equals false
declaration scope S2Use:
internal fun content decimal depends on y content integer
output o content (S2, S2)
declaration fun2 content decimal depends on y content integer equals y / 3
scope S2Use:
definition fun of y equals y / 2
definition o equals
(output of S2 with { -- cfun2: fun },
output of S2 with { -- cfun2: fun2 })
```
```catala-test-inline
$ catala Lcalc --avoid-exceptions -O --closure-conversion -s S2Use
let scope S2Use
(S2Use_in: S2Use_in)
: S2Use {
o:
(S2 {
dummy: bool;
cfun2: ((closure_env, integer) → decimal, closure_env)
},
S2 {
dummy: bool;
cfun2: ((closure_env, integer) → decimal, closure_env)
})
}
=
let set fun : ((closure_env, integer) → decimal, closure_env) =
(closure_fun, to_closure_env ())
in
let set o :
(S2 {
dummy: bool;
cfun2: ((closure_env, integer) → decimal, closure_env)
},
S2 {
dummy: bool;
cfun2: ((closure_env, integer) → decimal, closure_env)
}) =
(let result : S2 = S2 { S2_in cfun2_in = fun; } in
{ S2
dummy = result.dummy;
cfun2 = (closure_o, to_closure_env (result));
},
let result : S2 =
S2 { S2_in cfun2_in = (closure_o, to_closure_env ()); }
in
{ S2
dummy = result.dummy;
cfun2 = (closure_o, to_closure_env (result));
})
in
return { S2Use o = o; }
```