diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index d7ab4666..f1cc96ae 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -22,7 +22,7 @@ module D = Dcalc.Ast type 'm ctx = { decl_ctx : decl_ctx; name_context : string; - globally_bound_vars : 'm expr Var.Set.t; + globally_bound_vars : ('m expr, typ) Var.Map.t; } let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys @@ -44,10 +44,38 @@ let rec transform_closures_expr : Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:(transform_closures_expr ctx) e - | EVar v -> - ( (if Var.Set.mem v ctx.globally_bound_vars then Var.Set.empty - else Var.Set.singleton v), - (Bindlib.box_var v, m) ) + | EVar v -> ( + match Var.Map.find_opt v ctx.globally_bound_vars with + | None -> Var.Set.singleton v, (Bindlib.box_var v, m) + | Some (TArrow (targs, tret), _) -> + (* Here we eta-expand the argument to make sure function pointers are + correctly casted as closures *) + let args = Array.init (List.length targs) (fun _ -> Var.make "eta_arg") in + let arg_vars = + List.map2 + (fun v ty -> Expr.evar v (Expr.with_ty m ty)) + (Array.to_list args) targs + in + let e = + Expr.eabs + (Expr.bind args + (Expr.eapp (Expr.rebox e) arg_vars (Expr.with_ty m tret))) + targs m + in + let boxed = + let ctx = + (* We hide the type of the toplevel definition so that the function + doesn't loop *) + { + ctx with + globally_bound_vars = + Var.Map.add v (TAny, Pos.no_pos) ctx.globally_bound_vars; + } + in + Bindlib.box_apply (transform_closures_expr ctx) (Expr.Box.lift e) + in + Bindlib.unbox boxed + | Some _ -> Var.Set.empty, (Bindlib.box_var v, m)) | EMatch { e; cases; name } -> let free_vars, new_e = (transform_closures_expr ctx) e in (* We do not close the clotures inside the arms of the match expression, @@ -59,6 +87,11 @@ let rec transform_closures_expr : | EAbs { binder; tys } -> let vars, body = Bindlib.unmbind binder in let new_free_vars, new_body = (transform_closures_expr ctx) body in + let new_free_vars = + Array.fold_left + (fun acc v -> Var.Set.remove v acc) + new_free_vars vars + in let new_binder = Expr.bind vars new_body in ( Var.Set.union free_vars (Var.Set.diff new_free_vars @@ -75,6 +108,9 @@ let rec transform_closures_expr : (* let-binding, we should not close these *) let vars, body = Bindlib.unmbind binder in let free_vars, new_body = (transform_closures_expr ctx) body in + let free_vars = + Array.fold_left (fun acc v -> Var.Set.remove v acc) free_vars vars + in let new_binder = Expr.bind vars new_body in let free_vars, new_args = List.fold_right @@ -195,11 +231,17 @@ let rec transform_closures_expr : Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:(transform_closures_expr ctx) e - | EApp { f = EVar v, _; _ } when Var.Set.mem v ctx.globally_bound_vars -> - (* This corresponds to a scope call, which we don't want to transform*) - Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union - ~f:(transform_closures_expr ctx) - e + | EApp { f = EVar v, f_m; args } when Var.Map.mem v ctx.globally_bound_vars -> + (* This corresponds to a scope or toplevel function call, which we don't + want to transform*) + let free_vars, new_args = + List.fold_right + (fun arg (free_vars, new_args) -> + let new_free_vars, new_arg = (transform_closures_expr ctx) arg in + Var.Set.union free_vars new_free_vars, new_arg :: new_args) + args (Var.Set.empty, []) + in + free_vars, Expr.eapp (Expr.evar v f_m) new_args m | EApp { f = e1; args } -> let free_vars, new_e1 = (transform_closures_expr ctx) e1 in let code_env_var = Var.make "code_and_env" in @@ -286,12 +328,33 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = let new_scope_body_expr = Bindlib.bind_var scope_input_var new_scope_lets in - - ( Var.Set.add var toplevel_vars, + let ty = + let pos = Mark.get (ScopeName.get_info name) in + ( TArrow + ( [TStruct body.scope_body_input_struct, pos], + (TStruct body.scope_body_output_struct, pos) ), + pos ) + in + ( Var.Map.add var ty toplevel_vars, Bindlib.box_apply (fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr })) new_scope_body_expr ) + | Topdef (name, ty, (EAbs { binder; tys }, m)) -> + let v, expr = Bindlib.unmbind binder in + let ctx = + { + decl_ctx = p.decl_ctx; + name_context = Mark.remove (TopdefName.get_info name); + globally_bound_vars = toplevel_vars; + } + in + let _free_vars, new_expr = transform_closures_expr ctx expr in + let new_binder = Expr.bind v new_expr in + ( Var.Map.add var ty toplevel_vars, + Bindlib.box_apply + (fun e -> Topdef (name, ty, e)) + (Expr.Box.lift (Expr.eabs new_binder tys m)) ) | Topdef (name, ty, expr) -> let ctx = { @@ -301,12 +364,12 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = } in let _free_vars, new_expr = transform_closures_expr ctx expr in - ( Var.Set.add var toplevel_vars, + ( Var.Map.add var ty toplevel_vars, Bindlib.box_apply - (fun e -> Topdef (name, ty, e)) + (fun e -> Topdef (name, (TAny, Mark.get ty), e)) (Expr.Box.lift new_expr) )) ~varf:(fun v -> v) - Var.Set.empty p.code_items + Var.Map.empty p.code_items in (* Now we need to further tweak [decl_ctx] because some of the user-defined types can have closures in them and these closured might have changed type. @@ -342,17 +405,26 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = let replace_fun_typs t = if type_contains_arrow t then Mark.copy t TAny else t in - { - p.decl_ctx with - ctx_structs = - StructName.Map.map - (StructField.Map.map replace_fun_typs) - p.decl_ctx.ctx_structs; - ctx_enums = - EnumName.Map.map - (EnumConstructor.Map.map replace_fun_typs) - p.decl_ctx.ctx_enums; - } + let rec convert_ctx ctx = + { + ctx_struct_fields = ctx.ctx_struct_fields; + ctx_modules = ModuleName.Map.map convert_ctx ctx.ctx_modules; + ctx_structs = + StructName.Map.map + (StructField.Map.map replace_fun_typs) + ctx.ctx_structs; + ctx_enums = + EnumName.Map.map + (EnumConstructor.Map.map replace_fun_typs) + ctx.ctx_enums; + ctx_scopes = ctx.ctx_scopes; + ctx_topdefs = ctx.ctx_topdefs; + (* Toplevel definitions may not contain scope calls or take functions as + arguments at the moment, which ensures that their interfaces aren't + changed by the conversion *) + } + in + convert_ctx p.decl_ctx in Bindlib.box_apply (fun new_code_items -> @@ -528,13 +600,23 @@ let rec hoist_closures_code_item_list (fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr })) new_scope_body_expr ) + | Topdef (name, ty, (EAbs { binder; tys }, m)) -> + let v, expr = Bindlib.unmbind binder in + let new_hoisted_closures, new_expr = + hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr + in + let new_binder = Expr.bind v new_expr in + ( new_hoisted_closures, + Bindlib.box_apply + (fun e -> Topdef (name, ty, e)) + (Expr.Box.lift (Expr.eabs new_binder tys m)) ) | Topdef (name, ty, expr) -> let new_hoisted_closures, new_expr = hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr in ( new_hoisted_closures, Bindlib.box_apply - (fun e -> Topdef (name, ty, e)) + (fun e -> Topdef (name, (TAny, Mark.get ty), e)) (Expr.Box.lift new_expr) ) in let next_code_items = hoist_closures_code_item_list next_code_items in diff --git a/tests/test_func/good/closure_conversion_reduce.catala_en b/tests/test_func/good/closure_conversion_reduce.catala_en index 7d19c495..1f3d9277 100644 --- a/tests/test_func/good/closure_conversion_reduce.catala_en +++ b/tests/test_func/good/closure_conversion_reduce.catala_en @@ -59,7 +59,13 @@ let scope S (S_in: S_in {x_in: collection integer}): S {y: integer} = ESome reduce (λ (potential_max_1: integer) (potential_max_2: integer) → - if potential_max_1 < potential_max_2 then potential_max_1 + if + (let potential_max : integer = potential_max_1 in + potential_max) + < let potential_max : integer = potential_max_2 in + potential_max + then + potential_max_1 else potential_max_2) -1 x) ]