Closure hoisting (missing a bug on hardest case)

This commit is contained in:
Denis Merigoux 2023-06-18 18:08:18 +02:00
parent 2c45ca1599
commit a20adc0055
7 changed files with 254 additions and 92 deletions

View File

@ -10,6 +10,8 @@ RUN sudo apk add python3
RUN sudo ln -s /usr/bin/python3 /usr/bin/python
RUN sudo apk add g++
RUN sudo apk add make
# We also need bash to build JaneStreet's base
RUN sudo apk add bash
RUN mkdir catala
WORKDIR catala

View File

@ -443,6 +443,13 @@ let driver source_file (options : Cli.options) : int =
Message.emit_debug "Performing closure conversion...";
let prgm = Lcalc.Closure_conversion.closure_conversion prgm in
let prgm = Bindlib.unbox prgm in
(* let _output_file, with_output = get_output_format () in
with_output @@ fun fmt -> if Option.is_some options.ex_scope
then Format.fprintf fmt "%a\n" (Shared_ast.Print.scope
~debug:options.debug prgm.decl_ctx) (scope_uid,
Shared_ast.Program.get_scope_body prgm scope_uid) else
Format.fprintf fmt "%a\n" (Shared_ast.Print.program
~debug:options.debug) prgm; *)
let prgm =
if options.optimize then (
Message.emit_debug "Optimizing lambda calculus...";

View File

@ -19,9 +19,6 @@ open Shared_ast
open Ast
module D = Dcalc.Ast
(** TODO: This version is not yet debugged and ought to be specialized when
Lcalc has more structure. *)
type 'm ctx = {
decl_ctx : decl_ctx;
name_context : string;
@ -30,65 +27,7 @@ type 'm ctx = {
(** [tys_as_tanys tys] passes every element of [tys] through [Mark.map] with a
    function that always returns [TAny]. *)
let tys_as_tanys tys =
  let erase_ty marked_ty = Mark.map (fun _ -> TAny) marked_ty in
  List.map erase_ty tys
(** A closure collected by [hoist_context_free_closures]. *)
type 'm hoisted_closure = {
(* Fresh variable left in the expression in place of the closure. *)
name : 'm expr Var.t;
closure : 'm expr (* Starts with [EAbs]. *);
}
(** [hoist_context_free_closures ctx e] traverses [e] and replaces the [EAbs]
    nodes it reaches (other than [match] arms and the head of a let-binding
    style application) by fresh variables named after [ctx.name_context]. It
    returns the collected closures together with the rewritten, boxed
    expression.

    @raise Failure if an arm of a [match] is not an [EAbs]. *)
let rec hoist_context_free_closures :
type m. m ctx -> m expr -> m hoisted_closure list * m expr boxed =
fun ctx e ->
let m = Mark.get e in
match Mark.remove e with
| EMatch { e; cases; name } ->
(* From here on, [e] is the scrutinee of the match; it shadows the function
argument. *)
let collected_closures, new_e = (hoist_context_free_closures ctx) e in
(* We do not close the closures inside the arms of the match expression,
since they get a special treatment at compilation to Scalc. *)
let collected_closures, new_cases =
EnumConstructor.Map.fold
(fun cons e1 (collected_closures, new_cases) ->
match Mark.remove e1 with
| EAbs { binder; tys } ->
(* The arm itself stays an [EAbs]; only its body is traversed. *)
let vars, body = Bindlib.unmbind binder in
let new_collected_closures, new_body =
(hoist_context_free_closures ctx) body
in
let new_binder = Expr.bind vars new_body in
( collected_closures @ new_collected_closures,
EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Mark.get e1))
new_cases )
| _ -> failwith "should not happen")
cases
(collected_closures, EnumConstructor.Map.empty)
in
collected_closures, Expr.ematch new_e name new_cases m
| EApp { f = EAbs { binder; tys }, e1_pos; args } ->
(* let-binding, we should not close these *)
let vars, body = Bindlib.unmbind binder in
let collected_closures, new_body = (hoist_context_free_closures ctx) body in
let new_binder = Expr.bind vars new_body in
(* The closures found in the arguments are appended to those found in the
body. *)
let collected_closures, new_args =
List.fold_right
(fun arg (collected_closures, new_args) ->
let new_collected_closures, new_arg =
(hoist_context_free_closures ctx) arg
in
collected_closures @ new_collected_closures, new_arg :: new_args)
args (collected_closures, [])
in
(* The rebuilt abstraction gets its parameter types replaced through
[tys_as_tanys]. *)
( collected_closures,
Expr.eapp (Expr.eabs new_binder (tys_as_tanys tys) e1_pos) new_args m )
| EAbs _ ->
(* this is the closure we want to hoist*)
let closure_var = Var.make ctx.name_context in
[{ name = closure_var; closure = e }], Expr.make_var closure_var m
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _
| EVar _ ->
(* Every other node: traverse the sub-expressions and concatenate the closures
they yield. *)
Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_context_free_closures ctx) e
| _ -> .
(* Warning 32 (unused value) is disabled for this definition. *)
[@@warning "-32"]
(** {1 Transforming closures} *)
(** Returns the expression with closed closures and the set of free variables
inside this new expression. Implementation guided by
@ -294,7 +233,7 @@ let rec transform_closures_expr :
(* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the
type *)
let closure_conversion_scope_let ctx scope_body_expr =
let transform_closures_scope_let ctx scope_body_expr =
Scope.fold_right_lets
~f:(fun scope_let var_next acc ->
let _free_vars, new_scope_let_expr =
@ -323,7 +262,7 @@ let closure_conversion_scope_let ctx scope_body_expr =
(Expr.Box.lift new_scope_let_expr))
scope_body_expr
let closure_conversion (p : 'm program) : 'm program Bindlib.box =
let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
let _, new_code_items =
Scope.fold_map
~f:(fun toplevel_vars var code_item ->
@ -340,7 +279,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
}
in
let new_scope_lets =
closure_conversion_scope_let ctx scope_body_expr
transform_closures_scope_let ctx scope_body_expr
in
let new_scope_body_expr =
Bindlib.bind_var scope_input_var new_scope_lets
@ -415,3 +354,200 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
(fun new_code_items ->
{ code_items = new_code_items; decl_ctx = new_decl_ctx })
new_code_items
(** {1 Hoisting closures}*)
(* A closure extracted from its original position: collected by
[hoist_closures_expr] and turned into a top-level definition by
[hoist_closures_program]. *)
type 'm hoisted_closure = {
(* Fresh variable left in the expression in place of the closure. *)
name : 'm expr Var.t;
(* Type attached to the top-level definition created for the closure. *)
ty : typ;
closure : 'm expr (* Starts with [EAbs]. *);
}
(** [hoist_closures_expr name_context e] traverses [e] and replaces the [EAbs]
    nodes it reaches by fresh variables named ["closure_" ^ name_context]. It
    returns the list of (boxed) hoisted closures together with the rewritten,
    boxed expression.

    An [EAbs] is kept in place, and only its body traversed, when it is an
    arm of a [match], the head of a let-binding style application
    ([EApp {f = EAbs _; _}]), or a direct argument of one of the operators
    [HandleDefaultOpt], [Fold], [Map], [Filter], [Reduce]. The body of a
    hoisted [EAbs] is not traversed by this function.

    @raise Failure if an arm of a [match] is not an [EAbs]. *)
let rec hoist_closures_expr :
type m.
string -> m expr -> m hoisted_closure Bindlib.box list * m expr boxed =
fun name_context e ->
let m = Mark.get e in
match Mark.remove e with
| EMatch { e; cases; name } ->
(* From here on, [e] is the scrutinee of the match; it shadows the function
argument. *)
let collected_closures, new_e = (hoist_closures_expr name_context) e in
(* We do not close the closures inside the arms of the match expression,
since they get a special treatment at compilation to Scalc. *)
let collected_closures, new_cases =
EnumConstructor.Map.fold
(fun cons e1 (collected_closures, new_cases) ->
match Mark.remove e1 with
| EAbs { binder; tys } ->
(* The arm itself stays an [EAbs]; only its body is traversed. *)
let vars, body = Bindlib.unmbind binder in
let new_collected_closures, new_body =
(hoist_closures_expr name_context) body
in
let new_binder = Expr.bind vars new_body in
( collected_closures @ new_collected_closures,
EnumConstructor.Map.add cons
(Expr.eabs new_binder tys (Mark.get e1))
new_cases )
| _ -> failwith "should not happen")
cases
(collected_closures, EnumConstructor.Map.empty)
in
collected_closures, Expr.ematch new_e name new_cases m
| EApp { f = EAbs { binder; tys }, e1_pos; args } ->
(* let-binding, we should not close these *)
let vars, body = Bindlib.unmbind binder in
let collected_closures, new_body =
(hoist_closures_expr name_context) body
in
let new_binder = Expr.bind vars new_body in
(* NOTE(review): this branch appends the closures of each argument after the
accumulated ones, whereas the operator branch below prepends them; confirm
whether the resulting order matters to [hoist_closures_program]. *)
let collected_closures, new_args =
List.fold_right
(fun arg (collected_closures, new_args) ->
let new_collected_closures, new_arg =
(hoist_closures_expr name_context) arg
in
collected_closures @ new_collected_closures, new_arg :: new_args)
args (collected_closures, [])
in
(* The rebuilt abstraction gets its parameter types replaced through
[tys_as_tanys]. *)
( collected_closures,
Expr.eapp (Expr.eabs new_binder (tys_as_tanys tys) e1_pos) new_args m )
| EApp
{
f =
(EOp { op = HandleDefaultOpt | Fold | Map | Filter | Reduce; _ }, _)
as f;
args;
} ->
(* Special case for some operators: their arguments shall remain thunks (which
are closures), because if you want to extract the operator as a function you
need these closures to preserve evaluation order, but backends that don't
support closures will simply extract these operators in an inlined way and
skip the thunks. *)
let collected_closures, new_args =
List.fold_right
(fun (arg : (lcalc, m) gexpr) (collected_closures, new_args) ->
let m_arg = Mark.get arg in
match Mark.remove arg with
| EAbs { binder; tys } ->
(* A direct [EAbs] argument is rebuilt in place around its traversed body. *)
let vars, arg = Bindlib.unmbind binder in
let new_collected_closures, new_arg =
(hoist_closures_expr name_context) arg
in
let new_arg =
Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg)
in
new_collected_closures @ collected_closures, new_arg :: new_args
| _ ->
(* Any other argument goes through the regular traversal. *)
let new_collected_closures, new_arg =
hoist_closures_expr name_context arg
in
new_collected_closures @ collected_closures, new_arg :: new_args)
args ([], [])
in
(* [e] is not rebound in this branch, so [Mark.get e] is the mark of the whole
application. *)
collected_closures, Expr.eapp (Expr.box f) new_args (Mark.get e)
| EAbs { tys; _ } ->
(* this is the closure we want to hoist*)
let closure_var = Var.make ("closure_" ^ name_context) in
(* The recorded type is an arrow from the parameter types [tys] to [TAny]; the
expression itself is replaced by a reference to [closure_var]. *)
( [
Bindlib.box_apply
(fun e ->
{
name = closure_var;
ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m;
closure = e, m;
})
(fst (Expr.box e));
],
Expr.make_var closure_var m )
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
| EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _
| EVar _ ->
(* Every other node: traverse the sub-expressions and concatenate the closures
they yield. *)
Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_closures_expr name_context) e
| _ -> .
(* Warning 32 (unused value) is disabled for this definition. *)
[@@warning "-32"]
(* [Scope.map_exprs_in_lets] is reimplemented here as an explicit fold because
   the type is being changed. *)

(** [hoist_closures_scope_let name_context scope_body_expr] hoists the
    closures of every let-bound expression of [scope_body_expr] and of its
    final result. It returns the collected closures paired with the rebuilt,
    boxed scope body. Closures found in a let are named after the variable
    bound by that let; [name_context] is only used for the result
    expression. *)
let hoist_closures_scope_let name_context scope_body_expr =
  let hoist_result result_expr =
    (* INVARIANT here: the result expr of a scope is simply a struct
       containing all output variables so nothing should be converted here,
       so no need to take into account free variables. *)
    let closures, boxed_result =
      hoist_closures_expr name_context result_expr
    in
    ( closures,
      Bindlib.box_apply (fun r -> Result r) (Expr.Box.lift boxed_result) )
  in
  let hoist_let scope_let var_next (closures_below, lets_below) =
    let closures_here, boxed_expr =
      hoist_closures_expr (Bindlib.name_of var_next) scope_let.scope_let_expr
    in
    let rebuilt_let =
      Bindlib.box_apply2
        (fun scope_let_next scope_let_expr ->
          ScopeLet { scope_let with scope_let_next; scope_let_expr })
        (Bindlib.bind_var var_next lets_below)
        (Expr.Box.lift boxed_expr)
    in
    closures_here @ closures_below, rebuilt_let
  in
  Scope.fold_right_lets ~f:hoist_let ~init:hoist_result scope_body_expr
(** [hoist_closures_program p] hoists the closures of every scope body and
    top-level definition of [p], then conses one [Topdef] per hoisted closure
    in front of the resulting code items. The result is returned boxed. *)
let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
(* The accumulator threaded through the fold is the list of closures hoisted
so far. *)
let hoisted_closures, new_code_items =
Scope.fold_map
~f:(fun hoisted_closures _var code_item ->
match code_item with
| ScopeDef (name, body) ->
let scope_input_var, scope_body_expr =
Bindlib.unbind body.scope_body_expr
in
(* The scope name serves as naming context for the scope body. *)
let new_hoisted_closures, new_scope_lets =
hoist_closures_scope_let
(fst (ScopeName.get_info name))
scope_body_expr
in
let new_scope_body_expr =
Bindlib.bind_var scope_input_var new_scope_lets
in
( new_hoisted_closures @ hoisted_closures,
Bindlib.box_apply
(fun scope_body_expr ->
ScopeDef (name, { body with scope_body_expr }))
new_scope_body_expr )
| Topdef (name, ty, expr) ->
(* The top-level definition name serves as naming context. *)
let new_hoisted_closures, new_expr =
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
in
( new_hoisted_closures @ hoisted_closures,
Bindlib.box_apply
(fun e -> Topdef (name, ty, e))
(Expr.Box.lift new_expr) ))
~varf:(fun v -> v)
[] p.code_items
in
Bindlib.box_apply
(fun hoisted_closures ->
(* Each closure variable is bound over the code items built so far and a
[Topdef] for it is consed in front of them, so the last closure of the list
ends up first in the program. *)
let new_code_items =
List.fold_left
(fun (new_code_items : _ gexpr code_item_list Bindlib.box) hc ->
let next = Bindlib.bind_var hc.name new_code_items in
Bindlib.box_apply
(fun next ->
Cons
( Topdef
( TopdefName.fresh
(Bindlib.name_of hc.name, Expr.pos hc.closure),
hc.ty,
hc.closure ),
next ))
next)
new_code_items hoisted_closures
in
(* NOTE(review): the code items are unboxed inside the [box_apply] callback;
confirm this is the intended Bindlib usage. *)
{ p with code_items = Bindlib.unbox new_code_items })
(Bindlib.box_list hoisted_closures)
(** {1 Closure conversion} *)

(** [closure_conversion p] runs [transform_closures_program] on [p], unboxes
    the result, and then runs [hoist_closures_program] on it. *)
let closure_conversion (p : 'm program) : 'm program Bindlib.box =
  p |> transform_closures_program |> Bindlib.unbox |> hoist_closures_program

View File

@ -855,16 +855,15 @@ let code_item ?(debug = false) decl_ctx fmt c =
match c with
| ScopeDef (n, b) -> scope ~debug decl_ctx fmt (n, b)
| Topdef (n, ty, e) ->
Format.fprintf fmt "@[%a %a %a %a %a %a @]" keyword "let topval"
TopdefName.format_t n op_style ":" (typ decl_ctx) ty op_style "="
(expr ~debug ()) e
Format.fprintf fmt "@[<v 2>@[<hov 2>%a@ %a@ %a@ %a@ %a@]@ %a@]" keyword
"let topval" TopdefName.format_t n op_style ":" (typ decl_ctx) ty op_style
"=" (expr ~debug ()) e
let rec code_item_list ?(debug = false) decl_ctx fmt c =
match c with
| Nil -> ()
| Cons (c, b) ->
let _x, cl = Bindlib.unbind b in
Format.fprintf fmt "%a @.%a"
(code_item ~debug decl_ctx)
c

View File

@ -951,7 +951,7 @@ let rec scopes ~leave_unresolved ctx env = function
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in
( Env.add var uf env,
Bindlib.box_apply
(fun e -> A.Topdef (name, typ, e))
(fun e -> A.Topdef (name, Expr.ty e', e))
(Expr.Box.lift e') )
in
let next', env = scopes ~leave_unresolved ctx env next in

View File

@ -12,21 +12,28 @@ scope S:
```
```catala-test-inline
$ catala Lcalc -s S --avoid_exceptions -O --closure_conversion
$ catala Lcalc --avoid_exceptions -O --closure_conversion
type eoption = | ENone of unit | ESome of any
type S = { z: eoption integer; }
type S_in = { x_in: eoption bool; }
let topval closure_f : (closure_env, integer) → eoption integer =
λ (env: closure_env) (y: integer) →
ESome
match
(match (from_closure_env env).0 with
| ENone _ → ENone _
| ESome x → if x then ESome y else ESome - y)
with
| ENone _ → raise NoValueProvided
| ESome f → f
let scope S (S_in: S_in {x_in: eoption bool}): S {z: eoption integer} =
let get x : eoption bool = S_in.x_in in
let set f :
eoption ((closure_env, integer) → eoption integer * closure_env) =
ESome
(λ (env: closure_env) (y: integer) →
ESome
match
(match (from_closure_env env).0 with
| ENone _ → ENone _
| ESome x → if x then ESome y else ESome - y)
with
| ENone _ → raise NoValueProvided
| ESome f → f, to_closure_env (x))
ESome (closure_f, to_closure_env (x))
in
let set z : eoption integer =
ESome
@ -44,4 +51,5 @@ let scope S (S_in: S_in {x_in: eoption bool}): S {z: eoption integer} =
| ESome z → z
in
return { S z = z; }
```

View File

@ -10,7 +10,25 @@ scope S:
```
```catala-test-inline
$ catala Lcalc -s S --avoid_exceptions -O --closure_conversion
$ catala Lcalc --avoid_exceptions -O --closure_conversion
type eoption = | ENone of unit | ESome of any
type S = {
f: eoption ((closure_env, integer) → eoption integer * closure_env);
}
type S_in = { x_in: eoption bool; }
let topval closure_f : (closure_env, integer) → eoption integer =
λ (env: closure_env) (y: integer) →
ESome
match
(match (from_closure_env env).0 with
| ENone _ → ENone _
| ESome x → if x then ESome y else ESome - y)
with
| ENone _ → raise NoValueProvided
| ESome f → f
let scope S
(S_in: S_in {x_in: eoption bool})
: S {f: eoption ((closure_env, integer) → eoption integer * closure_env)}
@ -18,16 +36,8 @@ let scope S
let get x : eoption bool = S_in.x_in in
let set f :
eoption ((closure_env, integer) → eoption integer * closure_env) =
ESome
(λ (env: closure_env) (y: integer) →
ESome
match
(match (from_closure_env env).0 with
| ENone _ → ENone _
| ESome x → if x then ESome y else ESome - y)
with
| ENone _ → raise NoValueProvided
| ESome f → f, to_closure_env (x))
ESome (closure_f, to_closure_env (x))
in
return { S f = f; }
```