mirror of
https://github.com/CatalaLang/catala.git
synced 2024-09-19 16:28:12 +03:00
Fix closure-conversion
Joint debugging with @denismerigoux :)
This commit is contained in:
parent
80475ad5ef
commit
645c263ccc
@ -22,7 +22,7 @@ module D = Dcalc.Ast
|
||||
type 'm ctx = {
|
||||
decl_ctx : decl_ctx;
|
||||
name_context : string;
|
||||
globally_bound_vars : 'm expr Var.Set.t;
|
||||
globally_bound_vars : ('m expr, typ) Var.Map.t;
|
||||
}
|
||||
|
||||
let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
|
||||
@ -44,10 +44,38 @@ let rec transform_closures_expr :
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
|
||||
~f:(transform_closures_expr ctx)
|
||||
e
|
||||
| EVar v ->
|
||||
( (if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty
|
||||
else Var.Set.singleton v),
|
||||
(Bindlib.box_var v, m) )
|
||||
| EVar v -> (
|
||||
match Var.Map.find_opt v ctx.globally_bound_vars with
|
||||
| None -> Var.Set.singleton v, (Bindlib.box_var v, m)
|
||||
| Some (TArrow (targs, tret), _) ->
|
||||
(* Here we eta-expand the argument to make sure function pointers are
|
||||
correctly casted as closures *)
|
||||
let args = Array.init (List.length targs) (fun _ -> Var.make "eta_arg") in
|
||||
let arg_vars =
|
||||
List.map2
|
||||
(fun v ty -> Expr.evar v (Expr.with_ty m ty))
|
||||
(Array.to_list args) targs
|
||||
in
|
||||
let e =
|
||||
Expr.eabs
|
||||
(Expr.bind args
|
||||
(Expr.eapp (Expr.rebox e) arg_vars (Expr.with_ty m tret)))
|
||||
targs m
|
||||
in
|
||||
let boxed =
|
||||
let ctx =
|
||||
(* We hide the type of the toplevel definition so that the function
|
||||
doesn't loop *)
|
||||
{
|
||||
ctx with
|
||||
globally_bound_vars =
|
||||
Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars;
|
||||
}
|
||||
in
|
||||
Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e)
|
||||
in
|
||||
Bindlib.unbox boxed
|
||||
| Some _ -> Var.Set.empty, (Bindlib.box_var v, m))
|
||||
| EMatch { e; cases; name } ->
|
||||
let free_vars, new_e = (transform_closures_expr ctx) e in
|
||||
(* We do not close the clotures inside the arms of the match expression,
|
||||
@ -59,6 +87,11 @@ let rec transform_closures_expr :
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_free_vars, new_body = (transform_closures_expr ctx) body in
|
||||
let new_free_vars =
|
||||
Array.fold_left
|
||||
(fun acc v -> Var.Set.remove v acc)
|
||||
new_free_vars vars
|
||||
in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
( Var.Set.union free_vars
|
||||
(Var.Set.diff new_free_vars
|
||||
@ -75,6 +108,9 @@ let rec transform_closures_expr :
|
||||
(* let-binding, we should not close these *)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let free_vars, new_body = (transform_closures_expr ctx) body in
|
||||
let free_vars =
|
||||
Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars
|
||||
in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
let free_vars, new_args =
|
||||
List.fold_right
|
||||
@ -195,11 +231,17 @@ let rec transform_closures_expr :
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
|
||||
~f:(transform_closures_expr ctx)
|
||||
e
|
||||
| EApp { f = EVar v, _; _ } when Var.Set.mem v ctx.globally_bound_vars ->
|
||||
(* This corresponds to a scope call, which we don't want to transform*)
|
||||
Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union
|
||||
~f:(transform_closures_expr ctx)
|
||||
e
|
||||
| EApp { f = EVar v, f_m; args } when Var.Map.mem v ctx.globally_bound_vars ->
|
||||
(* This corresponds to a scope or toplevel function call, which we don't
|
||||
want to transform*)
|
||||
let free_vars, new_args =
|
||||
List.fold_right
|
||||
(fun arg (free_vars, new_args) ->
|
||||
let new_free_vars, new_arg = (transform_closures_expr ctx) arg in
|
||||
Var.Set.union free_vars new_free_vars, new_arg :: new_args)
|
||||
args (Var.Set.empty, [])
|
||||
in
|
||||
free_vars, Expr.eapp (Expr.evar v f_m) new_args m
|
||||
| EApp { f = e1; args } ->
|
||||
let free_vars, new_e1 = (transform_closures_expr ctx) e1 in
|
||||
let code_env_var = Var.make "code_and_env" in
|
||||
@ -286,12 +328,33 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
let new_scope_body_expr =
|
||||
Bindlib.bind_var scope_input_var new_scope_lets
|
||||
in
|
||||
|
||||
( Var.Set.add var toplevel_vars,
|
||||
let ty =
|
||||
let pos = Mark.get (ScopeName.get_info name) in
|
||||
( TArrow
|
||||
( [TStruct body.scope_body_input_struct, pos],
|
||||
(TStruct body.scope_body_output_struct, pos) ),
|
||||
pos )
|
||||
in
|
||||
( Var.Map.add var ty toplevel_vars,
|
||||
Bindlib.box_apply
|
||||
(fun scope_body_expr ->
|
||||
ScopeDef (name, { body with scope_body_expr }))
|
||||
new_scope_body_expr )
|
||||
| Topdef (name, ty, (EAbs { binder; tys }, m)) ->
|
||||
let v, expr = Bindlib.unmbind binder in
|
||||
let ctx =
|
||||
{
|
||||
decl_ctx = p.decl_ctx;
|
||||
name_context = Mark.remove (TopdefName.get_info name);
|
||||
globally_bound_vars = toplevel_vars;
|
||||
}
|
||||
in
|
||||
let _free_vars, new_expr = transform_closures_expr ctx expr in
|
||||
let new_binder = Expr.bind v new_expr in
|
||||
( Var.Map.add var ty toplevel_vars,
|
||||
Bindlib.box_apply
|
||||
(fun e -> Topdef (name, ty, e))
|
||||
(Expr.Box.lift (Expr.eabs new_binder tys m)) )
|
||||
| Topdef (name, ty, expr) ->
|
||||
let ctx =
|
||||
{
|
||||
@ -301,12 +364,12 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
}
|
||||
in
|
||||
let _free_vars, new_expr = transform_closures_expr ctx expr in
|
||||
( Var.Set.add var toplevel_vars,
|
||||
( Var.Map.add var ty toplevel_vars,
|
||||
Bindlib.box_apply
|
||||
(fun e -> Topdef (name, ty, e))
|
||||
(fun e -> Topdef (name, (TAny, Mark.get ty), e))
|
||||
(Expr.Box.lift new_expr) ))
|
||||
~varf:(fun v -> v)
|
||||
Var.Set.empty p.code_items
|
||||
Var.Map.empty p.code_items
|
||||
in
|
||||
(* Now we need to further tweak [decl_ctx] because some of the user-defined
|
||||
types can have closures in them and these closured might have changed type.
|
||||
@ -342,17 +405,26 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
let replace_fun_typs t =
|
||||
if type_contains_arrow t then Mark.copy t TAny else t
|
||||
in
|
||||
{
|
||||
p.decl_ctx with
|
||||
ctx_structs =
|
||||
StructName.Map.map
|
||||
(StructField.Map.map replace_fun_typs)
|
||||
p.decl_ctx.ctx_structs;
|
||||
ctx_enums =
|
||||
EnumName.Map.map
|
||||
(EnumConstructor.Map.map replace_fun_typs)
|
||||
p.decl_ctx.ctx_enums;
|
||||
}
|
||||
let rec convert_ctx ctx =
|
||||
{
|
||||
ctx_struct_fields = ctx.ctx_struct_fields;
|
||||
ctx_modules = ModuleName.Map.map convert_ctx ctx.ctx_modules;
|
||||
ctx_structs =
|
||||
StructName.Map.map
|
||||
(StructField.Map.map replace_fun_typs)
|
||||
ctx.ctx_structs;
|
||||
ctx_enums =
|
||||
EnumName.Map.map
|
||||
(EnumConstructor.Map.map replace_fun_typs)
|
||||
ctx.ctx_enums;
|
||||
ctx_scopes = ctx.ctx_scopes;
|
||||
ctx_topdefs = ctx.ctx_topdefs;
|
||||
(* Toplevel definitions may not contain scope calls or take functions as
|
||||
arguments at the moment, which ensures that their interfaces aren't
|
||||
changed by the conversion *)
|
||||
}
|
||||
in
|
||||
convert_ctx p.decl_ctx
|
||||
in
|
||||
Bindlib.box_apply
|
||||
(fun new_code_items ->
|
||||
@ -528,13 +600,23 @@ let rec hoist_closures_code_item_list
|
||||
(fun scope_body_expr ->
|
||||
ScopeDef (name, { body with scope_body_expr }))
|
||||
new_scope_body_expr )
|
||||
| Topdef (name, ty, (EAbs { binder; tys }, m)) ->
|
||||
let v, expr = Bindlib.unmbind binder in
|
||||
let new_hoisted_closures, new_expr =
|
||||
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
|
||||
in
|
||||
let new_binder = Expr.bind v new_expr in
|
||||
( new_hoisted_closures,
|
||||
Bindlib.box_apply
|
||||
(fun e -> Topdef (name, ty, e))
|
||||
(Expr.Box.lift (Expr.eabs new_binder tys m)) )
|
||||
| Topdef (name, ty, expr) ->
|
||||
let new_hoisted_closures, new_expr =
|
||||
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
|
||||
in
|
||||
( new_hoisted_closures,
|
||||
Bindlib.box_apply
|
||||
(fun e -> Topdef (name, ty, e))
|
||||
(fun e -> Topdef (name, (TAny, Mark.get ty), e))
|
||||
(Expr.Box.lift new_expr) )
|
||||
in
|
||||
let next_code_items = hoist_closures_code_item_list next_code_items in
|
||||
|
@ -59,7 +59,13 @@ let scope S (S_in: S_in {x_in: collection integer}): S {y: integer} =
|
||||
ESome
|
||||
reduce
|
||||
(λ (potential_max_1: integer) (potential_max_2: integer) →
|
||||
if potential_max_1 < potential_max_2 then potential_max_1
|
||||
if
|
||||
(let potential_max : integer = potential_max_1 in
|
||||
potential_max)
|
||||
< let potential_max : integer = potential_max_2 in
|
||||
potential_max
|
||||
then
|
||||
potential_max_1
|
||||
else potential_max_2)
|
||||
-1
|
||||
x) ]
|
||||
|
Loading…
Reference in New Issue
Block a user