Fix closure-conversion

Joint debugging with @denismerigoux :)
This commit is contained in:
Louis Gesbert 2023-11-28 11:11:33 +01:00
parent 80475ad5ef
commit 645c263ccc
2 changed files with 116 additions and 28 deletions

View File

@ -22,7 +22,7 @@ module D = Dcalc.Ast
type 'm ctx = {
decl_ctx : decl_ctx;
name_context : string;
globally_bound_vars : 'm expr Var.Set.t;
globally_bound_vars : ('m expr, typ) Var.Map.t;
}
let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
@ -44,10 +44,38 @@ let rec transform_closures_expr :
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
~f:(transform_closures_expr ctx)
e
| EVar v ->
( (if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
else Var.Set.singleton v),
(Bindlib.box_var v, m) )
| EVar v -> (
match Var.Map.find_opt v ctx.globally_bound_vars with
| None -> Var.Set.singleton v, (Bindlib.box_var v, m)
| Some (TArrow (targs, tret), _) ->
(* Here we eta-expand the argument to make sure function pointers are
correctly casted as closures *)
let args = Array.init (List.length targs) (fun _ -> Var.make "eta_arg") in
let arg_vars =
List.map2
(fun v ty -> Expr.evar v (Expr.with_ty m ty))
(Array.to_list args) targs
in
let e =
Expr.eabs
(Expr.bind args
(Expr.eapp (Expr.rebox e) arg_vars (Expr.with_ty m tret)))
targs m
in
let boxed =
let ctx =
(* We hide the type of the toplevel definition so that the function
doesn't loop *)
{
ctx with
globally_bound_vars =
Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars;
}
in
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
in
Bindlib.unbox boxed
| Some _ -> Var.Set.empty, (Bindlib.box_var v, m))
| EMatch { e; cases; name } ->
let free_vars, new_e = (transform_closures_expr ctx) e in
(* We do not close the clotures inside the arms of the match expression,
@ -59,6 +87,11 @@ let rec transform_closures_expr :
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let new_free_vars, new_body = (transform_closures_expr ctx) body in
let new_free_vars =
Array.fold_left
(fun acc v -> Var.Set.remove v acc)
new_free_vars vars
in
let new_binder = Expr.bind vars new_body in
( Var.Set.union free_vars
(Var.Set.diff new_free_vars
@ -75,6 +108,9 @@ let rec transform_closures_expr :
(* let-binding, we should not close these *)
let vars, body = Bindlib.unmbind binder in
let free_vars, new_body = (transform_closures_expr ctx) body in
let free_vars =
Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars
in
let new_binder = Expr.bind vars new_body in
let free_vars, new_args =
List.fold_right
@ -195,11 +231,17 @@ let rec transform_closures_expr :
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
~f:(transform_closures_expr ctx)
e
| EApp { f = EVar v, _; _ } when Var.Set.mem v ctx.globally_bound_vars ->
(* This corresponds to a scope call, which we don't want to transform*)
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
~f:(transform_closures_expr ctx)
e
| EApp { f = EVar v, f_m; args } when Var.Map.mem v ctx.globally_bound_vars ->
(* This corresponds to a scope or toplevel function call, which we don't
want to transform*)
let free_vars, new_args =
List.fold_right
(fun arg (free_vars, new_args) ->
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
args (Var.Set.empty, [])
in
free_vars, Expr.eapp (Expr.evar v f_m) new_args m
| EApp { f = e1; args } ->
let free_vars, new_e1 = (transform_closures_expr ctx) e1 in
let code_env_var = Var.make "code_and_env" in
@ -286,12 +328,33 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
let new_scope_body_expr =
Bindlib.bind_var scope_input_var new_scope_lets
in
( Var.Set.add var toplevel_vars,
let ty =
let pos = Mark.get (ScopeName.get_info name) in
( TArrow
( [TStruct body.scope_body_input_struct, pos],
(TStruct body.scope_body_output_struct, pos) ),
pos )
in
( Var.Map.add var ty toplevel_vars,
Bindlib.box_apply
(fun scope_body_expr ->
ScopeDef (name, { body with scope_body_expr }))
new_scope_body_expr )
| Topdef (name, ty, (EAbs { binder; tys }, m)) ->
let v, expr = Bindlib.unmbind binder in
let ctx =
{
decl_ctx = p.decl_ctx;
name_context = Mark.remove (TopdefName.get_info name);
globally_bound_vars = toplevel_vars;
}
in
let _free_vars, new_expr = transform_closures_expr ctx expr in
let new_binder = Expr.bind v new_expr in
( Var.Map.add var ty toplevel_vars,
Bindlib.box_apply
(fun e -> Topdef (name, ty, e))
(Expr.Box.lift (Expr.eabs new_binder tys m)) )
| Topdef (name, ty, expr) ->
let ctx =
{
@ -301,12 +364,12 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
}
in
let _free_vars, new_expr = transform_closures_expr ctx expr in
( Var.Set.add var toplevel_vars,
( Var.Map.add var ty toplevel_vars,
Bindlib.box_apply
(fun e -> Topdef (name, ty, e))
(fun e -> Topdef (name, (TAny, Mark.get ty), e))
(Expr.Box.lift new_expr) ))
~varf:(fun v -> v)
Var.Set.empty p.code_items
Var.Map.empty p.code_items
in
(* Now we need to further tweak [decl_ctx] because some of the user-defined
types can have closures in them and these closured might have changed type.
@ -342,17 +405,26 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
let replace_fun_typs t =
if type_contains_arrow t then Mark.copy t TAny else t
in
{
p.decl_ctx with
ctx_structs =
StructName.Map.map
(StructField.Map.map replace_fun_typs)
p.decl_ctx.ctx_structs;
ctx_enums =
EnumName.Map.map
(EnumConstructor.Map.map replace_fun_typs)
p.decl_ctx.ctx_enums;
}
let rec convert_ctx ctx =
{
ctx_struct_fields = ctx.ctx_struct_fields;
ctx_modules = ModuleName.Map.map convert_ctx ctx.ctx_modules;
ctx_structs =
StructName.Map.map
(StructField.Map.map replace_fun_typs)
ctx.ctx_structs;
ctx_enums =
EnumName.Map.map
(EnumConstructor.Map.map replace_fun_typs)
ctx.ctx_enums;
ctx_scopes = ctx.ctx_scopes;
ctx_topdefs = ctx.ctx_topdefs;
(* Toplevel definitions may not contain scope calls or take functions as
arguments at the moment, which ensures that their interfaces aren't
changed by the conversion *)
}
in
convert_ctx p.decl_ctx
in
Bindlib.box_apply
(fun new_code_items ->
@ -528,13 +600,23 @@ let rec hoist_closures_code_item_list
(fun scope_body_expr ->
ScopeDef (name, { body with scope_body_expr }))
new_scope_body_expr )
| Topdef (name, ty, (EAbs { binder; tys }, m)) ->
let v, expr = Bindlib.unmbind binder in
let new_hoisted_closures, new_expr =
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
in
let new_binder = Expr.bind v new_expr in
( new_hoisted_closures,
Bindlib.box_apply
(fun e -> Topdef (name, ty, e))
(Expr.Box.lift (Expr.eabs new_binder tys m)) )
| Topdef (name, ty, expr) ->
let new_hoisted_closures, new_expr =
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
in
( new_hoisted_closures,
Bindlib.box_apply
(fun e -> Topdef (name, ty, e))
(fun e -> Topdef (name, (TAny, Mark.get ty), e))
(Expr.Box.lift new_expr) )
in
let next_code_items = hoist_closures_code_item_list next_code_items in

View File

@ -59,7 +59,13 @@ let scope S (S_in: S_in {x_in: collection integer}): S {y: integer} =
ESome
reduce
(λ (potential_max_1: integer) (potential_max_2: integer) →
if potential_max_1 < potential_max_2 then potential_max_1
if
(let potential_max : integer = potential_max_1 in
potential_max)
< let potential_max : integer = potential_max_2 in
potential_max
then
potential_max_1
else potential_max_2)
-1
x) ]