Preserve and propagate types through closure conversion

some of the types (in particular, in hoisted closures) could not be
reconstructed afterwards. This properly propagates the types, including to
closure deconstruction time, giving additional insurance; and allowing
monomorphisation not to choke on the result.
This commit is contained in:
Louis Gesbert 2024-05-30 16:10:21 +02:00
parent 4acf321309
commit 035dff35a3
4 changed files with 100 additions and 87 deletions

View File

@ -24,7 +24,31 @@ type 'm ctx = {
globally_bound_vars : ('m expr, typ) Var.Map.t;
}
let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
(** Function types will be transformed in this way throughout, including in
[decl_ctx] *)
let rec translate_type t =
let pos = Mark.get t in
match Mark.remove t with
| TArrow (t1, t2) ->
( TTuple
[
( TArrow
( (TClosureEnv, Pos.no_pos) :: List.map translate_type t1,
translate_type t2 ),
Pos.no_pos );
TClosureEnv, Pos.no_pos;
],
pos )
| TDefault t' -> TDefault (translate_type t'), pos
| TOption t' -> TOption (translate_type t'), pos
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
| TArray ts -> TArray (translate_type ts), pos
| TTuple ts -> TTuple (List.map translate_type ts), pos
let translate_mark e = Mark.map_mark (Expr.map_ty translate_type) e
let join_vars : ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t =
fun m1 m2 -> Var.Map.union (fun _ a _ -> Some a) m1 m2
(** {1 Transforming closures}*)
@ -33,19 +57,20 @@ let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10
(environment-passing closure conversion). *)
let rec transform_closures_expr :
type m. m ctx -> m expr -> m expr Var.Set.t * m expr boxed =
type m. m ctx -> m expr -> (m expr, m mark) Var.Map.t * m expr boxed =
fun ctx e ->
let e = translate_mark e in
let m = Mark.get e in
match Mark.remove e with
| EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _
| ELit _ | EExternal _ | EAssert _ | EFatalError _ | EIfThenElse _
| ERaiseEmpty | ECatchEmpty _ ->
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
~f:(transform_closures_expr ctx)
e
| EVar v -> (
match Var.Map.find_opt v ctx.globally_bound_vars with
| None -> Var.Set.singleton v, (Bindlib.box_var v, m)
| None -> Var.Map.singleton v m, (Bindlib.box_var v, m)
| Some (TArrow (targs, tret), _) ->
(* Here we eta-expand the argument to make sure function pointers are
correctly casted as closures *)
@ -69,13 +94,13 @@ let rec transform_closures_expr :
{
ctx with
globally_bound_vars =
Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars;
Var.Map.add v (Expr.maybe_ty m) ctx.globally_bound_vars;
}
in
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
in
Bindlib.unbox boxed
| Some _ -> Var.Set.empty, (Bindlib.box_var v, m))
| Some _ -> Var.Map.empty, (Bindlib.box_var v, m))
| EMatch { e; cases; name } ->
let free_vars, new_e = (transform_closures_expr ctx) e in
(* We do not close the clotures inside the arms of the match expression,
@ -89,13 +114,11 @@ let rec transform_closures_expr :
let new_free_vars, new_body = (transform_closures_expr ctx) body in
let new_free_vars =
Array.fold_left
(fun acc v -> Var.Set.remove v acc)
(fun acc v -> Var.Map.remove v acc)
new_free_vars vars
in
let new_binder = Expr.bind vars new_body in
( Var.Set.union free_vars
(Var.Set.diff new_free_vars
(Var.Set.of_list (Array.to_list vars))),
( join_vars free_vars new_free_vars,
EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Mark.get e1))
new_cases )
@ -109,54 +132,58 @@ let rec transform_closures_expr :
let vars, body = Bindlib.unmbind binder in
let free_vars, new_body = (transform_closures_expr ctx) body in
let free_vars =
Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars
Array.fold_left (fun acc v -> Var.Map.remove v acc) free_vars vars
in
let new_binder = Expr.bind vars new_body in
let free_vars, new_args =
List.fold_right
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
join_vars free_vars new_free_vars, new_arg :: new_args)
args (free_vars, [])
in
( free_vars,
Expr.eapp
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos)
~f:(Expr.eabs new_binder (List.map translate_type tys) e1_pos)
~args:new_args ~tys m )
| EAbs { binder; tys } ->
(* λ x.t *)
let binder_mark = Expr.with_ty m (TAny, Expr.mark_pos m) in
let binder_pos = Expr.mark_pos binder_mark in
let binder_pos = Expr.mark_pos m in
let mark_ty ty = Expr.with_ty m ty in
(* Converting the closure. *)
let vars, body = Bindlib.unmbind binder in
(* t *)
let body_vars, new_body = (transform_closures_expr ctx) body in
(* [[t]] *)
let extra_vars =
Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars))
Array.fold_left (fun m v -> Var.Map.remove v m) body_vars vars
in
let extra_vars_list = Var.Map.bindings extra_vars in
let extra_vars_types =
List.map (fun (_, m) -> Expr.maybe_ty m) extra_vars_list
in
let extra_vars_list = Var.Set.elements extra_vars in
(* x1, ..., xn *)
let code_var = Var.make ctx.name_context in
(* code *)
let closure_env_arg_var = Var.make "env" in
let closure_env_var = Var.make "env" in
let any_ty = TAny, binder_pos in
let env_ty = TTuple extra_vars_types, binder_pos in
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
let new_closure_body =
Expr.make_let_in closure_env_var any_ty
Expr.make_let_in closure_env_var env_ty
(Expr.eappop
~op:(Operator.FromClosureEnv, binder_pos)
~tys:[TClosureEnv, binder_pos]
~args:[Expr.evar closure_env_arg_var binder_mark]
binder_mark)
~args:
[Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, binder_pos))]
(mark_ty env_ty))
(Expr.make_multiple_let_in
(Array.of_list extra_vars_list)
(List.map (fun _ -> any_ty) extra_vars_list)
(Array.of_list (List.map fst extra_vars_list))
extra_vars_types
(List.mapi
(fun i _ ->
Expr.make_tupleaccess
(Expr.evar closure_env_var binder_mark)
(Expr.evar closure_env_var (mark_ty env_ty))
i
(List.length extra_vars_list)
binder_pos)
@ -167,33 +194,39 @@ let rec transform_closures_expr :
(* fun env arg0 ... -> new_closure_body *)
let new_closure =
Expr.make_abs
(Array.concat [Array.make 1 closure_env_arg_var; vars])
(Array.append [| closure_env_arg_var |] vars)
new_closure_body
((TClosureEnv, binder_pos) :: tys)
(Expr.pos e)
in
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
( extra_vars,
Expr.make_let_in code_var
(TAny, Expr.pos e)
new_closure
Expr.make_let_in code_var new_closure_ty new_closure
(Expr.make_tuple
((Bindlib.box_var code_var, binder_mark)
((Bindlib.box_var code_var, mark_ty new_closure_ty)
:: [
Expr.eappop
~op:(Operator.ToClosureEnv, binder_pos)
~tys:[TAny, Expr.pos e]
~tys:
[
( (if extra_vars_list = [] then TLit TUnit
else TTuple extra_vars_types),
binder_pos );
]
~args:
[
(if extra_vars_list = [] then Expr.elit LUnit binder_mark
(if extra_vars_list = [] then
Expr.elit LUnit (mark_ty (TLit TUnit, binder_pos))
else
Expr.etuple
(List.map
(fun extra_var ->
Bindlib.box_var extra_var, binder_mark)
(fun (extra_var, m) ->
( Bindlib.box_var extra_var,
Expr.with_pos binder_pos m ))
extra_vars_list)
m);
(mark_ty (TTuple extra_vars_types, binder_pos)));
]
(Mark.get e);
(mark_ty (TClosureEnv, binder_pos));
])
m)
(Expr.pos e) )
@ -219,16 +252,16 @@ let rec transform_closures_expr :
let new_arg =
Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg)
in
Var.Set.union free_vars new_free_vars, new_arg :: new_args
join_vars free_vars new_free_vars, new_arg :: new_args
| _ ->
let new_free_vars, new_arg = transform_closures_expr ctx arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args (Var.Set.empty, [])
join_vars free_vars new_free_vars, new_arg :: new_args)
args (Var.Map.empty, [])
in
free_vars, Expr.eappop ~op ~tys ~args:new_args (Mark.get e)
| EAppOp _ ->
(* This corresponds to an operator call, which we don't want to transform *)
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
~f:(transform_closures_expr ctx)
e
| EApp { f = EVar v, f_m; args; tys }
@ -239,12 +272,16 @@ let rec transform_closures_expr :
List.fold_right
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args (Var.Set.empty, [])
join_vars free_vars new_free_vars, new_arg :: new_args)
args (Var.Map.empty, [])
in
free_vars, Expr.eapp ~f:(Expr.evar v f_m) ~args:new_args ~tys m
| EApp { f = e1; args; tys } ->
let free_vars, new_e1 = (transform_closures_expr ctx) e1 in
let tys = List.map translate_type tys in
let pos = Expr.mark_pos m in
let env_arg_ty = TClosureEnv, Expr.pos new_e1 in
let fun_ty = TArrow (env_arg_ty :: tys, Expr.maybe_ty m), pos in
let code_env_var = Var.make "code_and_env" in
let code_env_expr =
let pos = Expr.pos e1 in
@ -252,8 +289,7 @@ let rec transform_closures_expr :
(Expr.with_ty (Mark.get e1)
( TTuple
[
( TArrow ((TClosureEnv, pos) :: tys, (TAny, Expr.pos e)),
Expr.pos e );
TArrow ((TClosureEnv, pos) :: tys, Expr.maybe_ty m), Expr.pos e;
TClosureEnv, pos;
],
pos ))
@ -264,24 +300,23 @@ let rec transform_closures_expr :
List.fold_right
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
join_vars free_vars new_free_vars, new_arg :: new_args)
args (free_vars, [])
in
let call_expr =
let m1 = Mark.get e1 in
let pos = Expr.mark_pos m in
let env_arg_ty = TClosureEnv, Expr.pos e1 in
let fun_ty = TArrow (env_arg_ty :: tys, (TAny, Expr.pos e)), Expr.pos e in
let m1 = Mark.get new_e1 in
Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty]
[
Expr.make_tupleaccess code_env_expr 0 2 pos;
Expr.make_tupleaccess code_env_expr 1 2 pos;
]
(Expr.eapp
~f:(Bindlib.box_var code_var, m1)
~args:((Bindlib.box_var env_var, m1) :: new_args)
~tys:(env_arg_ty :: tys) m)
(Expr.pos e)
(Expr.make_app
(Bindlib.box_var code_var, Expr.with_ty m1 fun_ty)
((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args)
(env_arg_ty
:: (* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *) tys)
pos)
pos
in
( free_vars,
Expr.make_let_in code_env_var
@ -393,33 +428,15 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
capture footprint. See
[tests/tests_func/good/scope_call_func_struct_closure.catala_en]. *)
let new_decl_ctx =
let rec replace_fun_typs t =
match Mark.remove t with
| TArrow (t1, t2) ->
( TTuple
[
( TArrow
( (TClosureEnv, Pos.no_pos) :: List.map replace_fun_typs t1,
replace_fun_typs t2 ),
Pos.no_pos );
TClosureEnv, Pos.no_pos;
],
Mark.get t )
| TDefault t' -> TDefault (replace_fun_typs t'), Mark.get t
| TOption t' -> TOption (replace_fun_typs t'), Mark.get t
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
| TArray ts -> TArray (replace_fun_typs ts), Mark.get t
| TTuple ts -> TTuple (List.map replace_fun_typs ts), Mark.get t
in
{
p.decl_ctx with
ctx_structs =
StructName.Map.map
(StructField.Map.map replace_fun_typs)
(StructField.Map.map translate_type)
p.decl_ctx.ctx_structs;
ctx_enums =
EnumName.Map.map
(EnumConstructor.Map.map replace_fun_typs)
(EnumConstructor.Map.map translate_type)
p.decl_ctx.ctx_enums;
(* Toplevel definitions may not contain scope calls or take functions as
arguments at the moment, which ensures that their interfaces aren't
@ -489,9 +506,7 @@ let rec hoist_closures_expr :
args (collected_closures, [])
in
( collected_closures,
Expr.eapp
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos)
~args:new_args ~tys m )
Expr.eapp ~f:(Expr.eabs new_binder tys e1_pos) ~args:new_args ~tys m )
| EAppOp
{
op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op;
@ -532,13 +547,9 @@ let rec hoist_closures_expr :
to name wrangling in the backends. We need to have a better system for
name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but
also OCaml, Python, etc. *)
( [
{
name = closure_var;
ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m;
closure = Expr.rebox e;
};
],
let pos = Expr.mark_pos m in
let ty = Expr.maybe_ty ~typ:(TArrow (tys, (TAny, pos))) m in
( [{ name = closure_var; ty; closure = Expr.rebox e }],
Expr.make_var closure_var m )
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EFatalError _ | EAppOp _ | EIfThenElse _
@ -660,9 +671,9 @@ let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
(** {1 Closure conversion}*)
let closure_conversion (p : 'm program) : untyped program =
let closure_conversion (p : 'm program) : 'm program =
let new_p = transform_closures_program p in
let new_p = hoist_closures_program (Bindlib.unbox new_p) in
(* FIXME: either fix the types of the marks, or remove the types annotations
during the main processing (rather than requiring a new traversal) *)
Program.untype (Bindlib.unbox new_p)
Bindlib.unbox new_p

View File

@ -21,4 +21,4 @@
After closure conversion, closure hoisting is perform and all closures end
up as toplevel definitions. *)
val closure_conversion : 'm Ast.program -> Shared_ast.untyped Ast.program
val closure_conversion : 'm Ast.program -> 'm Ast.program

View File

@ -100,6 +100,7 @@ module Map = struct
let empty = empty
let singleton v x = singleton (t v) x
let add v x m = add (t v) x m
let remove v m = remove (t v) m
let update v f m = update (t v) f m
let find v m = find (t v) m
let find_opt v m = find_opt (t v) m

View File

@ -64,6 +64,7 @@ module Map : sig
val empty : ('e, 'x) t
val singleton : 'e var -> 'x -> ('e, 'x) t
val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t
val remove : 'e var -> ('e, 'x) t -> ('e, 'x) t
val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t
val find : 'e var -> ('e, 'x) t -> 'x
val find_opt : 'e var -> ('e, 'x) t -> 'x option