From a20adc0055610f424bca93c1d26d831b619dcd8a Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Sun, 18 Jun 2023 18:08:18 +0200 Subject: [PATCH] Closure hoisting (missing a bug on hardest case) --- Dockerfile | 2 + compiler/driver.ml | 7 + compiler/lcalc/closure_conversion.ml | 266 +++++++++++++----- compiler/shared_ast/print.ml | 7 +- compiler/shared_ast/typing.ml | 2 +- .../good/closure_conversion.catala_en | 30 +- tests/test_func/good/closure_return.catala_en | 32 ++- 7 files changed, 254 insertions(+), 92 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7093f7ea..dec54020 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,6 +10,8 @@ RUN sudo apk add python3 RUN sudo ln -s /usr/bin/python3 /usr/bin/python RUN sudo apk add g++ RUN sudo apk add make +# We also need bash to build JaneStreet's base +RUN sudo apk add bash RUN mkdir catala WORKDIR catala diff --git a/compiler/driver.ml b/compiler/driver.ml index 56eff416..540dc0a8 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -443,6 +443,13 @@ let driver source_file (options : Cli.options) : int = Message.emit_debug "Performing closure conversion..."; let prgm = Lcalc.Closure_conversion.closure_conversion prgm in let prgm = Bindlib.unbox prgm in + (* let _output_file, with_output = get_output_format () in + with_output @@ fun fmt -> if Option.is_some options.ex_scope + then Format.fprintf fmt "%a\n" (Shared_ast.Print.scope + ~debug:options.debug prgm.decl_ctx) (scope_uid, + Shared_ast.Program.get_scope_body prgm scope_uid) else + Format.fprintf fmt "%a\n" (Shared_ast.Print.program + ~debug:options.debug) prgm; *) let prgm = if options.optimize then ( Message.emit_debug "Optimizing lambda calculus..."; diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 459fdadb..ff9b85fd 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -19,9 +19,6 @@ open Shared_ast open Ast module D = Dcalc.Ast -(** TODO: This version is not yet debugged and ought to be specialized when - Lcalc has more structure. *) - type 'm ctx = { decl_ctx : decl_ctx; name_context : string; @@ -30,65 +27,7 @@ type 'm ctx = { let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys -type 'm hoisted_closure = { - name : 'm expr Var.t; - closure : 'm expr (* Starts with [EAbs]. *); -} - -let rec hoist_context_free_closures : - type m. m ctx -> m expr -> m hoisted_closure list * m expr boxed = - fun ctx e -> - let m = Mark.get e in - match Mark.remove e with - | EMatch { e; cases; name } -> - let collected_closures, new_e = (hoist_context_free_closures ctx) e in - (* We do not close the closures inside the arms of the match expression, - since they get a special treatment at compilation to Scalc. *) - let collected_closures, new_cases = - EnumConstructor.Map.fold - (fun cons e1 (collected_closures, new_cases) -> - match Mark.remove e1 with - | EAbs { binder; tys } -> - let vars, body = Bindlib.unmbind binder in - let new_collected_closures, new_body = - (hoist_context_free_closures ctx) body - in - let new_binder = Expr.bind vars new_body in - ( collected_closures @ new_collected_closures, - EnumConstructor.Map.add cons - (Expr.eabs new_binder tys (Mark.get e1)) - new_cases ) - | _ -> failwith "should not happen") - cases - (collected_closures, EnumConstructor.Map.empty) - in - collected_closures, Expr.ematch new_e name new_cases m - | EApp { f = EAbs { binder; tys }, e1_pos; args } -> - (* let-binding, we should not close these *) - let vars, body = Bindlib.unmbind binder in - let collected_closures, new_body = (hoist_context_free_closures ctx) body in - let new_binder = Expr.bind vars new_body in - let collected_closures, new_args = - List.fold_right - (fun arg (collected_closures, new_args) -> - let new_collected_closures, new_arg = - (hoist_context_free_closures ctx) arg - in - collected_closures @ new_collected_closures, new_arg :: new_args) - args (collected_closures, []) - in - ( collected_closures, - Expr.eapp (Expr.eabs new_binder (tys_as_tanys tys) e1_pos) new_args m ) - | EAbs _ -> - (* this is the closure we want to hoist*) - let closure_var = Var.make ctx.name_context in - [{ name = closure_var; closure = e }], Expr.make_var closure_var m - | EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ - | EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _ - | EVar _ -> - Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_context_free_closures ctx) e - | _ -> . - [@@warning "-32"] +(** { 1 Transforming closures}*) (** Returns the expression with closed closures and the set of free variables inside this new expression. Implementation guided by @@ -294,7 +233,7 @@ let rec transform_closures_expr : (* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the type *) -let closure_conversion_scope_let ctx scope_body_expr = +let transform_closures_scope_let ctx scope_body_expr = Scope.fold_right_lets ~f:(fun scope_let var_next acc -> let _free_vars, new_scope_let_expr = @@ -323,7 +262,7 @@ let closure_conversion_scope_let ctx scope_body_expr = (Expr.Box.lift new_scope_let_expr)) scope_body_expr -let closure_conversion (p : 'm program) : 'm program Bindlib.box = +let transform_closures_program (p : 'm program) : 'm program Bindlib.box = let _, new_code_items = Scope.fold_map ~f:(fun toplevel_vars var code_item -> @@ -340,7 +279,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box = } in let new_scope_lets = - closure_conversion_scope_let ctx scope_body_expr + transform_closures_scope_let ctx scope_body_expr in let new_scope_body_expr = Bindlib.bind_var scope_input_var new_scope_lets @@ -415,3 +354,200 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box = (fun new_code_items -> { code_items = new_code_items; decl_ctx = new_decl_ctx }) new_code_items + +(** {1 Hoisting closures}*) + +type 'm hoisted_closure = { + name : 'm expr Var.t; + ty : typ; + closure : 'm expr (* Starts with [EAbs]. *); +} + +let rec hoist_closures_expr : + type m. + string -> m expr -> m hoisted_closure Bindlib.box list * m expr boxed = + fun name_context e -> + let m = Mark.get e in + match Mark.remove e with + | EMatch { e; cases; name } -> + let collected_closures, new_e = (hoist_closures_expr name_context) e in + (* We do not close the closures inside the arms of the match expression, + since they get a special treatment at compilation to Scalc. *) + let collected_closures, new_cases = + EnumConstructor.Map.fold + (fun cons e1 (collected_closures, new_cases) -> + match Mark.remove e1 with + | EAbs { binder; tys } -> + let vars, body = Bindlib.unmbind binder in + let new_collected_closures, new_body = + (hoist_closures_expr name_context) body + in + let new_binder = Expr.bind vars new_body in + ( collected_closures @ new_collected_closures, + EnumConstructor.Map.add cons + (Expr.eabs new_binder tys (Mark.get e1)) + new_cases ) + | _ -> failwith "should not happen") + cases + (collected_closures, EnumConstructor.Map.empty) + in + collected_closures, Expr.ematch new_e name new_cases m + | EApp { f = EAbs { binder; tys }, e1_pos; args } -> + (* let-binding, we should not close these *) + let vars, body = Bindlib.unmbind binder in + let collected_closures, new_body = + (hoist_closures_expr name_context) body + in + let new_binder = Expr.bind vars new_body in + let collected_closures, new_args = + List.fold_right + (fun arg (collected_closures, new_args) -> + let new_collected_closures, new_arg = + (hoist_closures_expr name_context) arg + in + collected_closures @ new_collected_closures, new_arg :: new_args) + args (collected_closures, []) + in + ( collected_closures, + Expr.eapp (Expr.eabs new_binder (tys_as_tanys tys) e1_pos) new_args m ) + | EApp + { + f = + (EOp { op = HandleDefaultOpt | Fold | Map | Filter | Reduce; _ }, _) + as f; + args; + } -> + (* Special case for some operators: its arguments closures thunks because if + you want to extract it as a function you need these closures to preserve + evaluation order, but backends that don't support closures will simply + extract these operators in a inlined way and skip the thunks. *) + let collected_closures, new_args = + List.fold_right + (fun (arg : (lcalc, m) gexpr) (collected_closures, new_args) -> + let m_arg = Mark.get arg in + match Mark.remove arg with + | EAbs { binder; tys } -> + let vars, arg = Bindlib.unmbind binder in + let new_collected_closures, new_arg = + (hoist_closures_expr name_context) arg + in + let new_arg = + Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg) + in + new_collected_closures @ collected_closures, new_arg :: new_args + | _ -> + let new_collected_closures, new_arg = + hoist_closures_expr name_context arg + in + new_collected_closures @ collected_closures, new_arg :: new_args) + args ([], []) + in + collected_closures, Expr.eapp (Expr.box f) new_args (Mark.get e) + | EAbs { tys; _ } -> + (* this is the closure we want to hoist*) + let closure_var = Var.make ("closure_" ^ name_context) in + ( [ + Bindlib.box_apply + (fun e -> + { + name = closure_var; + ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m; + closure = e, m; + }) + (fst (Expr.box e)); + ], + Expr.make_var closure_var m ) + | EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ + | EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _ + | EVar _ -> + Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_closures_expr name_context) e + | _ -> . + [@@warning "-32"] + +(* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the + type *) +let hoist_closures_scope_let name_context scope_body_expr = + Scope.fold_right_lets + ~f:(fun scope_let var_next (hoisted_closures, next_scope_lets) -> + let new_hoisted_closures, new_scope_let_expr = + (hoist_closures_expr (Bindlib.name_of var_next)) + scope_let.scope_let_expr + in + ( new_hoisted_closures @ hoisted_closures, + Bindlib.box_apply2 + (fun scope_let_next scope_let_expr -> + ScopeLet { scope_let with scope_let_next; scope_let_expr }) + (Bindlib.bind_var var_next next_scope_lets) + (Expr.Box.lift new_scope_let_expr) )) + ~init:(fun res -> + let hoisted_closures, new_scope_let_expr = + (hoist_closures_expr name_context) res + in + (* INVARIANT here: the result expr of a scope is simply a struct + containing all output variables so nothing should be converted here, so + no need to take into account free variables. *) + ( hoisted_closures, + Bindlib.box_apply + (fun res -> Result res) + (Expr.Box.lift new_scope_let_expr) )) + scope_body_expr + +let hoist_closures_program (p : 'm program) : 'm program Bindlib.box = + let hoisted_closures, new_code_items = + Scope.fold_map + ~f:(fun hoisted_closures _var code_item -> + match code_item with + | ScopeDef (name, body) -> + let scope_input_var, scope_body_expr = + Bindlib.unbind body.scope_body_expr + in + let new_hoisted_closures, new_scope_lets = + hoist_closures_scope_let + (fst (ScopeName.get_info name)) + scope_body_expr + in + let new_scope_body_expr = + Bindlib.bind_var scope_input_var new_scope_lets + in + ( new_hoisted_closures @ hoisted_closures, + Bindlib.box_apply + (fun scope_body_expr -> + ScopeDef (name, { body with scope_body_expr })) + new_scope_body_expr ) + | Topdef (name, ty, expr) -> + let new_hoisted_closures, new_expr = + hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr + in + ( new_hoisted_closures @ hoisted_closures, + Bindlib.box_apply + (fun e -> Topdef (name, ty, e)) + (Expr.Box.lift new_expr) )) + ~varf:(fun v -> v) + [] p.code_items + in + Bindlib.box_apply + (fun hoisted_closures -> + let new_code_items = + List.fold_left + (fun (new_code_items : _ gexpr code_item_list Bindlib.box) hc -> + let next = Bindlib.bind_var hc.name new_code_items in + Bindlib.box_apply + (fun next -> + Cons + ( Topdef + ( TopdefName.fresh + (Bindlib.name_of hc.name, Expr.pos hc.closure), + hc.ty, + hc.closure ), + next )) + next) + new_code_items hoisted_closures + in + { p with code_items = Bindlib.unbox new_code_items }) + (Bindlib.box_list hoisted_closures) + +(** { 1 Closure conversion }*) + +let closure_conversion (p : 'm program) : 'm program Bindlib.box = + let new_p = transform_closures_program p in + hoist_closures_program (Bindlib.unbox new_p) diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index 7660db9c..b137c49e 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -855,16 +855,15 @@ let code_item ?(debug = false) decl_ctx fmt c = match c with | ScopeDef (n, b) -> scope ~debug decl_ctx fmt (n, b) | Topdef (n, ty, e) -> - Format.fprintf fmt "@[%a %a %a %a %a %a @]" keyword "let topval" - TopdefName.format_t n op_style ":" (typ decl_ctx) ty op_style "=" - (expr ~debug ()) e + Format.fprintf fmt "@[@[%a@ %a@ %a@ %a@ %a@]@ %a@]" keyword + "let topval" TopdefName.format_t n op_style ":" (typ decl_ctx) ty op_style + "=" (expr ~debug ()) e let rec code_item_list ?(debug = false) decl_ctx fmt c = match c with | Nil -> () | Cons (c, b) -> let _x, cl = Bindlib.unbind b in - Format.fprintf fmt "%a @.%a" (code_item ~debug decl_ctx) c diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index a1e9e677..8e89ad8f 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -951,7 +951,7 @@ let rec scopes ~leave_unresolved ctx env = function let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in ( Env.add var uf env, Bindlib.box_apply - (fun e -> A.Topdef (name, typ, e)) + (fun e -> A.Topdef (name, Expr.ty e', e)) (Expr.Box.lift e') ) in let next', env = scopes ~leave_unresolved ctx env next in diff --git a/tests/test_func/good/closure_conversion.catala_en b/tests/test_func/good/closure_conversion.catala_en index 5fff50b4..9b4cbb26 100644 --- a/tests/test_func/good/closure_conversion.catala_en +++ b/tests/test_func/good/closure_conversion.catala_en @@ -12,21 +12,28 @@ scope S: ``` ```catala-test-inline -$ catala Lcalc -s S --avoid_exceptions -O --closure_conversion +$ catala Lcalc --avoid_exceptions -O --closure_conversion +type eoption = | ENone of unit | ESome of any + +type S = { z: eoption integer; } + +type S_in = { x_in: eoption bool; } + +let topval closure_f : (closure_env, integer) → eoption integer = + λ (env: closure_env) (y: integer) → + ESome + match + (match (from_closure_env env).0 with + | ENone _ → ENone _ + | ESome x → if x then ESome y else ESome - y) + with + | ENone _ → raise NoValueProvided + | ESome f → f let scope S (S_in: S_in {x_in: eoption bool}): S {z: eoption integer} = let get x : eoption bool = S_in.x_in in let set f : eoption ((closure_env, integer) → eoption integer * closure_env) = - ESome - (λ (env: closure_env) (y: integer) → - ESome - match - (match (from_closure_env env).0 with - | ENone _ → ENone _ - | ESome x → if x then ESome y else ESome - y) - with - | ENone _ → raise NoValueProvided - | ESome f → f, to_closure_env (x)) + ESome (closure_f, to_closure_env (x)) in let set z : eoption integer = ESome @@ -44,4 +51,5 @@ let scope S (S_in: S_in {x_in: eoption bool}): S {z: eoption integer} = | ESome z → z in return { S z = z; } + ``` diff --git a/tests/test_func/good/closure_return.catala_en b/tests/test_func/good/closure_return.catala_en index ae554393..fea06f77 100644 --- a/tests/test_func/good/closure_return.catala_en +++ b/tests/test_func/good/closure_return.catala_en @@ -10,7 +10,25 @@ scope S: ``` ```catala-test-inline -$ catala Lcalc -s S --avoid_exceptions -O --closure_conversion +$ catala Lcalc --avoid_exceptions -O --closure_conversion +type eoption = | ENone of unit | ESome of any + +type S = { + f: eoption ((closure_env, integer) → eoption integer * closure_env); + } + +type S_in = { x_in: eoption bool; } + +let topval closure_f : (closure_env, integer) → eoption integer = + λ (env: closure_env) (y: integer) → + ESome + match + (match (from_closure_env env).0 with + | ENone _ → ENone _ + | ESome x → if x then ESome y else ESome - y) + with + | ENone _ → raise NoValueProvided + | ESome f → f let scope S (S_in: S_in {x_in: eoption bool}) : S {f: eoption ((closure_env, integer) → eoption integer * closure_env)} @@ -18,16 +36,8 @@ let scope S let get x : eoption bool = S_in.x_in in let set f : eoption ((closure_env, integer) → eoption integer * closure_env) = - ESome - (λ (env: closure_env) (y: integer) → - ESome - match - (match (from_closure_env env).0 with - | ENone _ → ENone _ - | ESome x → if x then ESome y else ESome - y) - with - | ENone _ → raise NoValueProvided - | ESome f → f, to_closure_env (x)) + ESome (closure_f, to_closure_env (x)) in return { S f = f; } + ```