mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Preserve and propagate types through closure conversion
some of the types (in particular, in hoisted closures) could not be reconstructed afterwards. This properly propagates the types, including to closure deconstruction time, giving additional insurance; and allowing monomorphisation not to choke on the result.
This commit is contained in:
parent
4acf321309
commit
035dff35a3
@ -24,7 +24,31 @@ type 'm ctx = {
|
||||
globally_bound_vars : ('m expr, typ) Var.Map.t;
|
||||
}
|
||||
|
||||
let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
|
||||
(** Function types will be transformed in this way throughout, including in
|
||||
[decl_ctx] *)
|
||||
let rec translate_type t =
|
||||
let pos = Mark.get t in
|
||||
match Mark.remove t with
|
||||
| TArrow (t1, t2) ->
|
||||
( TTuple
|
||||
[
|
||||
( TArrow
|
||||
( (TClosureEnv, Pos.no_pos) :: List.map translate_type t1,
|
||||
translate_type t2 ),
|
||||
Pos.no_pos );
|
||||
TClosureEnv, Pos.no_pos;
|
||||
],
|
||||
pos )
|
||||
| TDefault t' -> TDefault (translate_type t'), pos
|
||||
| TOption t' -> TOption (translate_type t'), pos
|
||||
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
|
||||
| TArray ts -> TArray (translate_type ts), pos
|
||||
| TTuple ts -> TTuple (List.map translate_type ts), pos
|
||||
|
||||
let translate_mark e = Mark.map_mark (Expr.map_ty translate_type) e
|
||||
|
||||
let join_vars : ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t -> ('a, 'x) Var.Map.t =
|
||||
fun m1 m2 -> Var.Map.union (fun _ a _ -> Some a) m1 m2
|
||||
|
||||
(** {1 Transforming closures}*)
|
||||
|
||||
@ -33,19 +57,20 @@ let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
|
||||
http://gallium.inria.fr/~fpottier/mpri/cours04.pdf#page=10
|
||||
(environment-passing closure conversion). *)
|
||||
let rec transform_closures_expr :
|
||||
type m. m ctx -> m expr -> m expr Var.Set.t * m expr boxed =
|
||||
type m. m ctx -> m expr -> (m expr, m mark) Var.Map.t * m expr boxed =
|
||||
fun ctx e ->
|
||||
let e = translate_mark e in
|
||||
let m = Mark.get e in
|
||||
match Mark.remove e with
|
||||
| EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _
|
||||
| ELit _ | EExternal _ | EAssert _ | EFatalError _ | EIfThenElse _
|
||||
| ERaiseEmpty | ECatchEmpty _ ->
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
|
||||
Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
|
||||
~f:(transform_closures_expr ctx)
|
||||
e
|
||||
| EVar v -> (
|
||||
match Var.Map.find_opt v ctx.globally_bound_vars with
|
||||
| None -> Var.Set.singleton v, (Bindlib.box_var v, m)
|
||||
| None -> Var.Map.singleton v m, (Bindlib.box_var v, m)
|
||||
| Some (TArrow (targs, tret), _) ->
|
||||
(* Here we eta-expand the argument to make sure function pointers are
|
||||
correctly casted as closures *)
|
||||
@ -69,13 +94,13 @@ let rec transform_closures_expr :
|
||||
{
|
||||
ctx with
|
||||
globally_bound_vars =
|
||||
Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars;
|
||||
Var.Map.add v (Expr.maybe_ty m) ctx.globally_bound_vars;
|
||||
}
|
||||
in
|
||||
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
|
||||
in
|
||||
Bindlib.unbox boxed
|
||||
| Some _ -> Var.Set.empty, (Bindlib.box_var v, m))
|
||||
| Some _ -> Var.Map.empty, (Bindlib.box_var v, m))
|
||||
| EMatch { e; cases; name } ->
|
||||
let free_vars, new_e = (transform_closures_expr ctx) e in
|
||||
(* We do not close the clotures inside the arms of the match expression,
|
||||
@ -89,13 +114,11 @@ let rec transform_closures_expr :
|
||||
let new_free_vars, new_body = (transform_closures_expr ctx) body in
|
||||
let new_free_vars =
|
||||
Array.fold_left
|
||||
(fun acc v -> Var.Set.remove v acc)
|
||||
(fun acc v -> Var.Map.remove v acc)
|
||||
new_free_vars vars
|
||||
in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
( Var.Set.union free_vars
|
||||
(Var.Set.diff new_free_vars
|
||||
(Var.Set.of_list (Array.to_list vars))),
|
||||
( join_vars free_vars new_free_vars,
|
||||
EnumConstructor.Map.add cons
|
||||
(Expr.eabs new_binder tys (Mark.get e1))
|
||||
new_cases )
|
||||
@ -109,54 +132,58 @@ let rec transform_closures_expr :
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let free_vars, new_body = (transform_closures_expr ctx) body in
|
||||
let free_vars =
|
||||
Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars
|
||||
Array.fold_left (fun acc v -> Var.Map.remove v acc) free_vars vars
|
||||
in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
let free_vars, new_args =
|
||||
List.fold_right
|
||||
(fun arg (free_vars, new_args) ->
|
||||
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
|
||||
join_vars free_vars new_free_vars, new_arg :: new_args)
|
||||
args (free_vars, [])
|
||||
in
|
||||
( free_vars,
|
||||
Expr.eapp
|
||||
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos)
|
||||
~f:(Expr.eabs new_binder (List.map translate_type tys) e1_pos)
|
||||
~args:new_args ~tys m )
|
||||
| EAbs { binder; tys } ->
|
||||
(* λ x.t *)
|
||||
let binder_mark = Expr.with_ty m (TAny, Expr.mark_pos m) in
|
||||
let binder_pos = Expr.mark_pos binder_mark in
|
||||
let binder_pos = Expr.mark_pos m in
|
||||
let mark_ty ty = Expr.with_ty m ty in
|
||||
(* Converting the closure. *)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
(* t *)
|
||||
let body_vars, new_body = (transform_closures_expr ctx) body in
|
||||
(* [[t]] *)
|
||||
let extra_vars =
|
||||
Var.Set.diff body_vars (Var.Set.of_list (Array.to_list vars))
|
||||
Array.fold_left (fun m v -> Var.Map.remove v m) body_vars vars
|
||||
in
|
||||
let extra_vars_list = Var.Map.bindings extra_vars in
|
||||
let extra_vars_types =
|
||||
List.map (fun (_, m) -> Expr.maybe_ty m) extra_vars_list
|
||||
in
|
||||
let extra_vars_list = Var.Set.elements extra_vars in
|
||||
(* x1, ..., xn *)
|
||||
let code_var = Var.make ctx.name_context in
|
||||
(* code *)
|
||||
let closure_env_arg_var = Var.make "env" in
|
||||
let closure_env_var = Var.make "env" in
|
||||
let any_ty = TAny, binder_pos in
|
||||
let env_ty = TTuple extra_vars_types, binder_pos in
|
||||
(* let env = from_closure_env env in let arg0 = env.0 in ... *)
|
||||
let new_closure_body =
|
||||
Expr.make_let_in closure_env_var any_ty
|
||||
Expr.make_let_in closure_env_var env_ty
|
||||
(Expr.eappop
|
||||
~op:(Operator.FromClosureEnv, binder_pos)
|
||||
~tys:[TClosureEnv, binder_pos]
|
||||
~args:[Expr.evar closure_env_arg_var binder_mark]
|
||||
binder_mark)
|
||||
~args:
|
||||
[Expr.evar closure_env_arg_var (mark_ty (TClosureEnv, binder_pos))]
|
||||
(mark_ty env_ty))
|
||||
(Expr.make_multiple_let_in
|
||||
(Array.of_list extra_vars_list)
|
||||
(List.map (fun _ -> any_ty) extra_vars_list)
|
||||
(Array.of_list (List.map fst extra_vars_list))
|
||||
extra_vars_types
|
||||
(List.mapi
|
||||
(fun i _ ->
|
||||
Expr.make_tupleaccess
|
||||
(Expr.evar closure_env_var binder_mark)
|
||||
(Expr.evar closure_env_var (mark_ty env_ty))
|
||||
i
|
||||
(List.length extra_vars_list)
|
||||
binder_pos)
|
||||
@ -167,33 +194,39 @@ let rec transform_closures_expr :
|
||||
(* fun env arg0 ... -> new_closure_body *)
|
||||
let new_closure =
|
||||
Expr.make_abs
|
||||
(Array.concat [Array.make 1 closure_env_arg_var; vars])
|
||||
(Array.append [| closure_env_arg_var |] vars)
|
||||
new_closure_body
|
||||
((TClosureEnv, binder_pos) :: tys)
|
||||
(Expr.pos e)
|
||||
in
|
||||
let new_closure_ty = Expr.maybe_ty (Mark.get new_closure) in
|
||||
( extra_vars,
|
||||
Expr.make_let_in code_var
|
||||
(TAny, Expr.pos e)
|
||||
new_closure
|
||||
Expr.make_let_in code_var new_closure_ty new_closure
|
||||
(Expr.make_tuple
|
||||
((Bindlib.box_var code_var, binder_mark)
|
||||
((Bindlib.box_var code_var, mark_ty new_closure_ty)
|
||||
:: [
|
||||
Expr.eappop
|
||||
~op:(Operator.ToClosureEnv, binder_pos)
|
||||
~tys:[TAny, Expr.pos e]
|
||||
~tys:
|
||||
[
|
||||
( (if extra_vars_list = [] then TLit TUnit
|
||||
else TTuple extra_vars_types),
|
||||
binder_pos );
|
||||
]
|
||||
~args:
|
||||
[
|
||||
(if extra_vars_list = [] then Expr.elit LUnit binder_mark
|
||||
(if extra_vars_list = [] then
|
||||
Expr.elit LUnit (mark_ty (TLit TUnit, binder_pos))
|
||||
else
|
||||
Expr.etuple
|
||||
(List.map
|
||||
(fun extra_var ->
|
||||
Bindlib.box_var extra_var, binder_mark)
|
||||
(fun (extra_var, m) ->
|
||||
( Bindlib.box_var extra_var,
|
||||
Expr.with_pos binder_pos m ))
|
||||
extra_vars_list)
|
||||
m);
|
||||
(mark_ty (TTuple extra_vars_types, binder_pos)));
|
||||
]
|
||||
(Mark.get e);
|
||||
(mark_ty (TClosureEnv, binder_pos));
|
||||
])
|
||||
m)
|
||||
(Expr.pos e) )
|
||||
@ -219,16 +252,16 @@ let rec transform_closures_expr :
|
||||
let new_arg =
|
||||
Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg)
|
||||
in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args
|
||||
join_vars free_vars new_free_vars, new_arg :: new_args
|
||||
| _ ->
|
||||
let new_free_vars, new_arg = transform_closures_expr ctx arg in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
|
||||
args (Var.Set.empty, [])
|
||||
join_vars free_vars new_free_vars, new_arg :: new_args)
|
||||
args (Var.Map.empty, [])
|
||||
in
|
||||
free_vars, Expr.eappop ~op ~tys ~args:new_args (Mark.get e)
|
||||
| EAppOp _ ->
|
||||
(* This corresponds to an operator call, which we don't want to transform *)
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
|
||||
Expr.map_gather ~acc:Var.Map.empty ~join:join_vars
|
||||
~f:(transform_closures_expr ctx)
|
||||
e
|
||||
| EApp { f = EVar v, f_m; args; tys }
|
||||
@ -239,12 +272,16 @@ let rec transform_closures_expr :
|
||||
List.fold_right
|
||||
(fun arg (free_vars, new_args) ->
|
||||
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
|
||||
args (Var.Set.empty, [])
|
||||
join_vars free_vars new_free_vars, new_arg :: new_args)
|
||||
args (Var.Map.empty, [])
|
||||
in
|
||||
free_vars, Expr.eapp ~f:(Expr.evar v f_m) ~args:new_args ~tys m
|
||||
| EApp { f = e1; args; tys } ->
|
||||
let free_vars, new_e1 = (transform_closures_expr ctx) e1 in
|
||||
let tys = List.map translate_type tys in
|
||||
let pos = Expr.mark_pos m in
|
||||
let env_arg_ty = TClosureEnv, Expr.pos new_e1 in
|
||||
let fun_ty = TArrow (env_arg_ty :: tys, Expr.maybe_ty m), pos in
|
||||
let code_env_var = Var.make "code_and_env" in
|
||||
let code_env_expr =
|
||||
let pos = Expr.pos e1 in
|
||||
@ -252,8 +289,7 @@ let rec transform_closures_expr :
|
||||
(Expr.with_ty (Mark.get e1)
|
||||
( TTuple
|
||||
[
|
||||
( TArrow ((TClosureEnv, pos) :: tys, (TAny, Expr.pos e)),
|
||||
Expr.pos e );
|
||||
TArrow ((TClosureEnv, pos) :: tys, Expr.maybe_ty m), Expr.pos e;
|
||||
TClosureEnv, pos;
|
||||
],
|
||||
pos ))
|
||||
@ -264,24 +300,23 @@ let rec transform_closures_expr :
|
||||
List.fold_right
|
||||
(fun arg (free_vars, new_args) ->
|
||||
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
|
||||
join_vars free_vars new_free_vars, new_arg :: new_args)
|
||||
args (free_vars, [])
|
||||
in
|
||||
let call_expr =
|
||||
let m1 = Mark.get e1 in
|
||||
let pos = Expr.mark_pos m in
|
||||
let env_arg_ty = TClosureEnv, Expr.pos e1 in
|
||||
let fun_ty = TArrow (env_arg_ty :: tys, (TAny, Expr.pos e)), Expr.pos e in
|
||||
let m1 = Mark.get new_e1 in
|
||||
Expr.make_multiple_let_in [| code_var; env_var |] [fun_ty; env_arg_ty]
|
||||
[
|
||||
Expr.make_tupleaccess code_env_expr 0 2 pos;
|
||||
Expr.make_tupleaccess code_env_expr 1 2 pos;
|
||||
]
|
||||
(Expr.eapp
|
||||
~f:(Bindlib.box_var code_var, m1)
|
||||
~args:((Bindlib.box_var env_var, m1) :: new_args)
|
||||
~tys:(env_arg_ty :: tys) m)
|
||||
(Expr.pos e)
|
||||
(Expr.make_app
|
||||
(Bindlib.box_var code_var, Expr.with_ty m1 fun_ty)
|
||||
((Bindlib.box_var env_var, Expr.with_ty m1 env_arg_ty) :: new_args)
|
||||
(env_arg_ty
|
||||
:: (* List.map (fun (_, m) -> Expr.maybe_ty m) new_args *) tys)
|
||||
pos)
|
||||
pos
|
||||
in
|
||||
( free_vars,
|
||||
Expr.make_let_in code_env_var
|
||||
@ -393,33 +428,15 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
capture footprint. See
|
||||
[tests/tests_func/good/scope_call_func_struct_closure.catala_en]. *)
|
||||
let new_decl_ctx =
|
||||
let rec replace_fun_typs t =
|
||||
match Mark.remove t with
|
||||
| TArrow (t1, t2) ->
|
||||
( TTuple
|
||||
[
|
||||
( TArrow
|
||||
( (TClosureEnv, Pos.no_pos) :: List.map replace_fun_typs t1,
|
||||
replace_fun_typs t2 ),
|
||||
Pos.no_pos );
|
||||
TClosureEnv, Pos.no_pos;
|
||||
],
|
||||
Mark.get t )
|
||||
| TDefault t' -> TDefault (replace_fun_typs t'), Mark.get t
|
||||
| TOption t' -> TOption (replace_fun_typs t'), Mark.get t
|
||||
| TAny | TClosureEnv | TLit _ | TEnum _ | TStruct _ -> t
|
||||
| TArray ts -> TArray (replace_fun_typs ts), Mark.get t
|
||||
| TTuple ts -> TTuple (List.map replace_fun_typs ts), Mark.get t
|
||||
in
|
||||
{
|
||||
p.decl_ctx with
|
||||
ctx_structs =
|
||||
StructName.Map.map
|
||||
(StructField.Map.map replace_fun_typs)
|
||||
(StructField.Map.map translate_type)
|
||||
p.decl_ctx.ctx_structs;
|
||||
ctx_enums =
|
||||
EnumName.Map.map
|
||||
(EnumConstructor.Map.map replace_fun_typs)
|
||||
(EnumConstructor.Map.map translate_type)
|
||||
p.decl_ctx.ctx_enums;
|
||||
(* Toplevel definitions may not contain scope calls or take functions as
|
||||
arguments at the moment, which ensures that their interfaces aren't
|
||||
@ -489,9 +506,7 @@ let rec hoist_closures_expr :
|
||||
args (collected_closures, [])
|
||||
in
|
||||
( collected_closures,
|
||||
Expr.eapp
|
||||
~f:(Expr.eabs new_binder (tys_as_tanys tys) e1_pos)
|
||||
~args:new_args ~tys m )
|
||||
Expr.eapp ~f:(Expr.eabs new_binder tys e1_pos) ~args:new_args ~tys m )
|
||||
| EAppOp
|
||||
{
|
||||
op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op;
|
||||
@ -532,13 +547,9 @@ let rec hoist_closures_expr :
|
||||
to name wrangling in the backends. We need to have a better system for
|
||||
name disambiguation when for instance printing to Dcalc/Lcalc/Scalc but
|
||||
also OCaml, Python, etc. *)
|
||||
( [
|
||||
{
|
||||
name = closure_var;
|
||||
ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m;
|
||||
closure = Expr.rebox e;
|
||||
};
|
||||
],
|
||||
let pos = Expr.mark_pos m in
|
||||
let ty = Expr.maybe_ty ~typ:(TArrow (tys, (TAny, pos))) m in
|
||||
( [{ name = closure_var; ty; closure = Expr.rebox e }],
|
||||
Expr.make_var closure_var m )
|
||||
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
|
||||
| EArray _ | ELit _ | EAssert _ | EFatalError _ | EAppOp _ | EIfThenElse _
|
||||
@ -660,9 +671,9 @@ let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
|
||||
(** {1 Closure conversion}*)
|
||||
|
||||
let closure_conversion (p : 'm program) : untyped program =
|
||||
let closure_conversion (p : 'm program) : 'm program =
|
||||
let new_p = transform_closures_program p in
|
||||
let new_p = hoist_closures_program (Bindlib.unbox new_p) in
|
||||
(* FIXME: either fix the types of the marks, or remove the types annotations
|
||||
during the main processing (rather than requiring a new traversal) *)
|
||||
Program.untype (Bindlib.unbox new_p)
|
||||
Bindlib.unbox new_p
|
||||
|
@ -21,4 +21,4 @@
|
||||
After closure conversion, closure hoisting is perform and all closures end
|
||||
up as toplevel definitions. *)
|
||||
|
||||
val closure_conversion : 'm Ast.program -> Shared_ast.untyped Ast.program
|
||||
val closure_conversion : 'm Ast.program -> 'm Ast.program
|
||||
|
@ -100,6 +100,7 @@ module Map = struct
|
||||
let empty = empty
|
||||
let singleton v x = singleton (t v) x
|
||||
let add v x m = add (t v) x m
|
||||
let remove v m = remove (t v) m
|
||||
let update v f m = update (t v) f m
|
||||
let find v m = find (t v) m
|
||||
let find_opt v m = find_opt (t v) m
|
||||
|
@ -64,6 +64,7 @@ module Map : sig
|
||||
val empty : ('e, 'x) t
|
||||
val singleton : 'e var -> 'x -> ('e, 'x) t
|
||||
val add : 'e var -> 'x -> ('e, 'x) t -> ('e, 'x) t
|
||||
val remove : 'e var -> ('e, 'x) t -> ('e, 'x) t
|
||||
val update : 'e var -> ('x option -> 'x option) -> ('e, 'x) t -> ('e, 'x) t
|
||||
val find : 'e var -> ('e, 'x) t -> 'x
|
||||
val find_opt : 'e var -> ('e, 'x) t -> 'x option
|
||||
|
Loading…
Reference in New Issue
Block a user