diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 041b08e0..87ad5709 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -360,12 +360,11 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = type 'm hoisted_closure = { name : 'm expr Var.t; ty : typ; - closure : 'm expr (* Starts with [EAbs]. *); + closure : (lcalc, 'm) boxed_gexpr (* Starts with [EAbs]. *); } let rec hoist_closures_expr : - type m. - string -> m expr -> m hoisted_closure Bindlib.box list * m expr boxed = + type m. string -> m expr -> m hoisted_closure list * m expr boxed = fun name_context e -> let m = Mark.get e in match Mark.remove e with @@ -447,14 +446,11 @@ let rec hoist_closures_expr : (* this is the closure we want to hoist*) let closure_var = Var.make ("closure_" ^ name_context) in ( [ - Bindlib.box_apply - (fun e -> - { - name = closure_var; - ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m; - closure = e, m; - }) - (fst (Expr.box e)); + { + name = closure_var; + ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m; + closure = Expr.rebox e; + }; ], Expr.make_var closure_var m ) | EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ @@ -492,62 +488,80 @@ let hoist_closures_scope_let name_context scope_body_expr = (Expr.Box.lift new_scope_let_expr) )) scope_body_expr +let rec hoist_closures_code_item_list + (code_items : (lcalc, 'm) gexpr code_item_list) : + (lcalc, 'm) gexpr code_item_list Bindlib.box = + match code_items with + | Nil -> Bindlib.box Nil + | Cons (code_item, next_code_items) -> + let code_item_var, next_code_items = Bindlib.unbind next_code_items in + let hoisted_closures, new_code_item = + match code_item with + | ScopeDef (name, body) -> + let scope_input_var, scope_body_expr = + Bindlib.unbind body.scope_body_expr + in + let new_hoisted_closures, new_scope_lets = + hoist_closures_scope_let + (fst (ScopeName.get_info name)) + scope_body_expr + in + let new_scope_body_expr = + Bindlib.bind_var scope_input_var new_scope_lets + in + ( new_hoisted_closures, + Bindlib.box_apply + (fun scope_body_expr -> + ScopeDef (name, { body with scope_body_expr })) + new_scope_body_expr ) + | Topdef (name, ty, expr) -> + let new_hoisted_closures, new_expr = + hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr + in + ( new_hoisted_closures, + Bindlib.box_apply + (fun e -> Topdef (name, ty, e)) + (Expr.Box.lift new_expr) ) + in + let next_code_items = hoist_closures_code_item_list next_code_items in + let next_code_items = + Bindlib.box_apply2 + (fun next_code_items new_code_item -> + Cons (new_code_item, next_code_items)) + (Bindlib.bind_var code_item_var next_code_items) + new_code_item + in + let next_code_items = + List.fold_left + (fun (next_code_items : (lcalc, 'm) gexpr code_item_list Bindlib.box) + (hoisted_closure : 'm hoisted_closure) -> + let next_code_items = + Bindlib.bind_var hoisted_closure.name next_code_items + in + let closure, closure_mark = hoisted_closure.closure in + Bindlib.box_apply2 + (fun next_code_items closure -> + Cons + ( Topdef + ( TopdefName.fresh + ( Bindlib.name_of hoisted_closure.name, + Expr.mark_pos closure_mark ), + hoisted_closure.ty, + (closure, closure_mark) ), + next_code_items )) + next_code_items closure) + next_code_items hoisted_closures + in + next_code_items + let hoist_closures_program (p : 'm program) : 'm program Bindlib.box = - let hoisted_closures, new_code_items = - Scope.fold_map - ~f:(fun hoisted_closures _var code_item -> - match code_item with - | ScopeDef (name, body) -> - let scope_input_var, scope_body_expr = - Bindlib.unbind body.scope_body_expr - in - let new_hoisted_closures, new_scope_lets = - hoist_closures_scope_let - (fst (ScopeName.get_info name)) - scope_body_expr - in - let new_scope_body_expr = - Bindlib.bind_var scope_input_var new_scope_lets - in - ( new_hoisted_closures @ hoisted_closures, - Bindlib.box_apply - (fun scope_body_expr -> - ScopeDef (name, { body with scope_body_expr })) - new_scope_body_expr ) - | Topdef (name, ty, expr) -> - let new_hoisted_closures, new_expr = - hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr - in - ( new_hoisted_closures @ hoisted_closures, - Bindlib.box_apply - (fun e -> Topdef (name, ty, e)) - (Expr.Box.lift new_expr) )) - ~varf:(fun v -> v) - [] p.code_items - in + let new_code_items = hoist_closures_code_item_list p.code_items in (*TODO: we need to insert the hoisted closures just before the scopes they belong to, because some of them call sub-scopes and putting them all at the beginning breaks dependency ordering. *) Bindlib.box_apply - (fun hoisted_closures -> - let new_code_items = - List.fold_left - (fun (new_code_items : _ gexpr code_item_list Bindlib.box) hc -> - let next = Bindlib.bind_var hc.name new_code_items in - Bindlib.box_apply - (fun next -> - Cons - ( Topdef - ( TopdefName.fresh - (Bindlib.name_of hc.name, Expr.pos hc.closure), - hc.ty, - hc.closure ), - next )) - next) - new_code_items hoisted_closures - in - { p with code_items = Bindlib.unbox new_code_items }) - (Bindlib.box_list hoisted_closures) + (fun new_code_items -> { p with code_items = new_code_items }) + new_code_items (** { 1 Closure conversion }*) diff --git a/tests/test_func/good/scope_call_func_struct_closure.catala_en b/tests/test_func/good/scope_call_func_struct_closure.catala_en index 06c3aba1..d88c0f66 100644 --- a/tests/test_func/good/scope_call_func_struct_closure.catala_en +++ b/tests/test_func/good/scope_call_func_struct_closure.catala_en @@ -41,7 +41,107 @@ This test case is tricky because it creates a situation where the type of the two closures in Foo.r are different even with optimizations enabled. ```catala-test-inline -$ catala Lcalc --avoid_exceptions -O --closure_conversion -s Foo +$ catala Lcalc --avoid_exceptions -O --closure_conversion +type eoption = | ENone of unit | ESome of any + +type Result = { + r: eoption ((closure_env, integer) → eoption integer * closure_env); + q: eoption integer; + } + +type SubFoo1 = { + x: eoption integer; + y: eoption ((closure_env, integer) → eoption integer * closure_env); + } + +type SubFoo2 = { + x1: eoption integer; + y: eoption ((closure_env, integer) → eoption integer * closure_env); + } + +type Foo = { z: eoption integer; } + +type SubFoo1_in = { x_in: eoption integer; } + +type SubFoo2_in = { x1_in: eoption integer; x2_in: eoption integer; } + +type Foo_in = { b_in: eoption bool; } + +let topval closure_y : (closure_env, integer) → eoption integer = + λ (env: closure_env) (z: integer) → + ESome + match + (match (from_closure_env env).0 with + | ENone _ → ENone _ + | ESome x_0 → ESome (x_0 + z)) + with + | ENone _ → raise NoValueProvided + | ESome y → y +let scope SubFoo1 + (SubFoo1_in: SubFoo1_in {x_in: eoption integer}) + : SubFoo1 { + x: eoption integer; + y: eoption ((closure_env, integer) → eoption integer * closure_env) + } + = + let get x : eoption integer = SubFoo1_in.x_in in + let set y : + eoption ((closure_env, integer) → eoption integer * closure_env) = + ESome (closure_y, to_closure_env (x)) + in + return { SubFoo1 x = x; y = y; } +let topval closure_y : (closure_env, integer) → eoption integer = + λ (env: closure_env) (z: integer) → + let env1 : (eoption integer * eoption integer) = from_closure_env env in + ESome + match + (match + (match env1.0 with + | ENone _ → ENone _ + | ESome x1_1 → + match env1.1 with + | ENone _ → ENone _ + | ESome x1_0 → ESome (x1_0 + x1_1)) + with + | ENone _ → ENone _ + | ESome y_0 → ESome (y_0 + z)) + with + | ENone _ → raise NoValueProvided + | ESome y → y +let scope SubFoo2 + (SubFoo2_in: SubFoo2_in {x1_in: eoption integer; x2_in: eoption integer}) + : SubFoo2 { + x1: eoption integer; + y: eoption ((closure_env, integer) → eoption integer * closure_env) + } + = + let get x1 : eoption integer = SubFoo2_in.x1_in in + let get x2 : eoption integer = SubFoo2_in.x2_in in + let set y : + eoption ((closure_env, integer) → eoption integer * closure_env) = + ESome (closure_y, to_closure_env (x2, x1)) + in + return { SubFoo2 x1 = x1; y = y; } +let topval closure_r : (closure_env, integer) → eoption integer = + λ (env: closure_env) (param0: integer) → + match (SubFoo2 { SubFoo2_in x1_in = ESome 10; x2_in = ESome 10; }).y with + | ENone _ → ENone _ + | ESome result → + let code_and_env : + ((closure_env, integer) → eoption integer * closure_env) = + result + in + code_and_env.0 code_and_env.1 param0 +let topval closure_r : (closure_env, integer) → eoption integer = + λ (env: closure_env) (param0: integer) → + match (SubFoo1 { SubFoo1_in x_in = ESome 10; }).y with + | ENone _ → ENone _ + | ESome result → + let code_and_env : + ((closure_env, integer) → eoption integer * closure_env) = + result + in + code_and_env.0 code_and_env.1 param0 let scope Foo (Foo_in: Foo_in {b_in: eoption bool}) : Foo {z: eoption integer} @@ -80,21 +180,7 @@ let scope Foo ESome { SubFoo1 x = ESome result_0; - y = - ESome - (λ (env: closure_env) (param0: integer) → - match - (SubFoo1 { SubFoo1_in x_in = ESome 10; }).y - with - | ENone _ → ENone _ - | ESome result → - let code_and_env : - ((closure_env, integer) → eoption integer * - closure_env) = - result - in - code_and_env.0 code_and_env.1 param0, - to_closure_env ()); + y = ESome (closure_r, to_closure_env ()); }) with | ENone _ → ENone _ @@ -116,26 +202,7 @@ let scope Foo ESome { SubFoo2 x1 = ESome result_0; - y = - ESome - (λ (env: closure_env) (param0: integer) → - match - (SubFoo2 - { SubFoo2_in - x1_in = ESome 10; - x2_in = ESome 10; - }). - y - with - | ENone _ → ENone _ - | ESome result → - let code_and_env : - ((closure_env, integer) → eoption integer * - closure_env) = - result - in - code_and_env.0 code_and_env.1 param0, - to_closure_env ()); + y = ESome (closure_r, to_closure_env ()); }) with | ENone _ → ENone _ @@ -168,6 +235,7 @@ let scope Foo | ESome z → z in return { Foo z = z; } + ``` ```catala-test-inline