mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Closure hoisting (missing a bug on hardest case)
This commit is contained in:
parent
2c45ca1599
commit
a20adc0055
@ -10,6 +10,8 @@ RUN sudo apk add python3
|
||||
RUN sudo ln -s /usr/bin/python3 /usr/bin/python
|
||||
RUN sudo apk add g++
|
||||
RUN sudo apk add make
|
||||
# We also need bash to build JaneStreet's base
|
||||
RUN sudo apk add bash
|
||||
|
||||
RUN mkdir catala
|
||||
WORKDIR catala
|
||||
|
@ -443,6 +443,13 @@ let driver source_file (options : Cli.options) : int =
|
||||
Message.emit_debug "Performing closure conversion...";
|
||||
let prgm = Lcalc.Closure_conversion.closure_conversion prgm in
|
||||
let prgm = Bindlib.unbox prgm in
|
||||
(* let _output_file, with_output = get_output_format () in
|
||||
with_output @@ fun fmt -> if Option.is_some options.ex_scope
|
||||
then Format.fprintf fmt "%a\n" (Shared_ast.Print.scope
|
||||
~debug:options.debug prgm.decl_ctx) (scope_uid,
|
||||
Shared_ast.Program.get_scope_body prgm scope_uid) else
|
||||
Format.fprintf fmt "%a\n" (Shared_ast.Print.program
|
||||
~debug:options.debug) prgm; *)
|
||||
let prgm =
|
||||
if options.optimize then (
|
||||
Message.emit_debug "Optimizing lambda calculus...";
|
||||
|
@ -19,9 +19,6 @@ open Shared_ast
|
||||
open Ast
|
||||
module D = Dcalc.Ast
|
||||
|
||||
(** TODO: This version is not yet debugged and ought to be specialized when
|
||||
Lcalc has more structure. *)
|
||||
|
||||
type 'm ctx = {
|
||||
decl_ctx : decl_ctx;
|
||||
name_context : string;
|
||||
@ -30,65 +27,7 @@ type 'm ctx = {
|
||||
|
||||
let tys_as_tanys tys = List.map (fun x -> Mark.map (fun _ -> TAny) x) tys
|
||||
|
||||
type 'm hoisted_closure = {
|
||||
name : 'm expr Var.t;
|
||||
closure : 'm expr (* Starts with [EAbs]. *);
|
||||
}
|
||||
|
||||
let rec hoist_context_free_closures :
|
||||
type m. m ctx -> m expr -> m hoisted_closure list * m expr boxed =
|
||||
fun ctx e ->
|
||||
let m = Mark.get e in
|
||||
match Mark.remove e with
|
||||
| EMatch { e; cases; name } ->
|
||||
let collected_closures, new_e = (hoist_context_free_closures ctx) e in
|
||||
(* We do not close the closures inside the arms of the match expression,
|
||||
since they get a special treatment at compilation to Scalc. *)
|
||||
let collected_closures, new_cases =
|
||||
EnumConstructor.Map.fold
|
||||
(fun cons e1 (collected_closures, new_cases) ->
|
||||
match Mark.remove e1 with
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_collected_closures, new_body =
|
||||
(hoist_context_free_closures ctx) body
|
||||
in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
( collected_closures @ new_collected_closures,
|
||||
EnumConstructor.Map.add cons
|
||||
(Expr.eabs new_binder tys (Mark.get e1))
|
||||
new_cases )
|
||||
| _ -> failwith "should not happen")
|
||||
cases
|
||||
(collected_closures, EnumConstructor.Map.empty)
|
||||
in
|
||||
collected_closures, Expr.ematch new_e name new_cases m
|
||||
| EApp { f = EAbs { binder; tys }, e1_pos; args } ->
|
||||
(* let-binding, we should not close these *)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let collected_closures, new_body = (hoist_context_free_closures ctx) body in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
let collected_closures, new_args =
|
||||
List.fold_right
|
||||
(fun arg (collected_closures, new_args) ->
|
||||
let new_collected_closures, new_arg =
|
||||
(hoist_context_free_closures ctx) arg
|
||||
in
|
||||
collected_closures @ new_collected_closures, new_arg :: new_args)
|
||||
args (collected_closures, [])
|
||||
in
|
||||
( collected_closures,
|
||||
Expr.eapp (Expr.eabs new_binder (tys_as_tanys tys) e1_pos) new_args m )
|
||||
| EAbs _ ->
|
||||
(* this is the closure we want to hoist*)
|
||||
let closure_var = Var.make ctx.name_context in
|
||||
[{ name = closure_var; closure = e }], Expr.make_var closure_var m
|
||||
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
|
||||
| EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _
|
||||
| EVar _ ->
|
||||
Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_context_free_closures ctx) e
|
||||
| _ -> .
|
||||
[@@warning "-32"]
|
||||
(** { 1 Transforming closures}*)
|
||||
|
||||
(** Returns the expression with closed closures and the set of free variables
|
||||
inside this new expression. Implementation guided by
|
||||
@ -294,7 +233,7 @@ let rec transform_closures_expr :
|
||||
|
||||
(* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the
|
||||
type *)
|
||||
let closure_conversion_scope_let ctx scope_body_expr =
|
||||
let transform_closures_scope_let ctx scope_body_expr =
|
||||
Scope.fold_right_lets
|
||||
~f:(fun scope_let var_next acc ->
|
||||
let _free_vars, new_scope_let_expr =
|
||||
@ -323,7 +262,7 @@ let closure_conversion_scope_let ctx scope_body_expr =
|
||||
(Expr.Box.lift new_scope_let_expr))
|
||||
scope_body_expr
|
||||
|
||||
let closure_conversion (p : 'm program) : 'm program Bindlib.box =
|
||||
let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
let _, new_code_items =
|
||||
Scope.fold_map
|
||||
~f:(fun toplevel_vars var code_item ->
|
||||
@ -340,7 +279,7 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
|
||||
}
|
||||
in
|
||||
let new_scope_lets =
|
||||
closure_conversion_scope_let ctx scope_body_expr
|
||||
transform_closures_scope_let ctx scope_body_expr
|
||||
in
|
||||
let new_scope_body_expr =
|
||||
Bindlib.bind_var scope_input_var new_scope_lets
|
||||
@ -415,3 +354,200 @@ let closure_conversion (p : 'm program) : 'm program Bindlib.box =
|
||||
(fun new_code_items ->
|
||||
{ code_items = new_code_items; decl_ctx = new_decl_ctx })
|
||||
new_code_items
|
||||
|
||||
(** {1 Hoisting closures}*)
|
||||
|
||||
type 'm hoisted_closure = {
|
||||
name : 'm expr Var.t;
|
||||
ty : typ;
|
||||
closure : 'm expr (* Starts with [EAbs]. *);
|
||||
}
|
||||
|
||||
let rec hoist_closures_expr :
|
||||
type m.
|
||||
string -> m expr -> m hoisted_closure Bindlib.box list * m expr boxed =
|
||||
fun name_context e ->
|
||||
let m = Mark.get e in
|
||||
match Mark.remove e with
|
||||
| EMatch { e; cases; name } ->
|
||||
let collected_closures, new_e = (hoist_closures_expr name_context) e in
|
||||
(* We do not close the closures inside the arms of the match expression,
|
||||
since they get a special treatment at compilation to Scalc. *)
|
||||
let collected_closures, new_cases =
|
||||
EnumConstructor.Map.fold
|
||||
(fun cons e1 (collected_closures, new_cases) ->
|
||||
match Mark.remove e1 with
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_collected_closures, new_body =
|
||||
(hoist_closures_expr name_context) body
|
||||
in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
( collected_closures @ new_collected_closures,
|
||||
EnumConstructor.Map.add cons
|
||||
(Expr.eabs new_binder tys (Mark.get e1))
|
||||
new_cases )
|
||||
| _ -> failwith "should not happen")
|
||||
cases
|
||||
(collected_closures, EnumConstructor.Map.empty)
|
||||
in
|
||||
collected_closures, Expr.ematch new_e name new_cases m
|
||||
| EApp { f = EAbs { binder; tys }, e1_pos; args } ->
|
||||
(* let-binding, we should not close these *)
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let collected_closures, new_body =
|
||||
(hoist_closures_expr name_context) body
|
||||
in
|
||||
let new_binder = Expr.bind vars new_body in
|
||||
let collected_closures, new_args =
|
||||
List.fold_right
|
||||
(fun arg (collected_closures, new_args) ->
|
||||
let new_collected_closures, new_arg =
|
||||
(hoist_closures_expr name_context) arg
|
||||
in
|
||||
collected_closures @ new_collected_closures, new_arg :: new_args)
|
||||
args (collected_closures, [])
|
||||
in
|
||||
( collected_closures,
|
||||
Expr.eapp (Expr.eabs new_binder (tys_as_tanys tys) e1_pos) new_args m )
|
||||
| EApp
|
||||
{
|
||||
f =
|
||||
(EOp { op = HandleDefaultOpt | Fold | Map | Filter | Reduce; _ }, _)
|
||||
as f;
|
||||
args;
|
||||
} ->
|
||||
(* Special case for some operators: its arguments closures thunks because if
|
||||
you want to extract it as a function you need these closures to preserve
|
||||
evaluation order, but backends that don't support closures will simply
|
||||
extract these operators in a inlined way and skip the thunks. *)
|
||||
let collected_closures, new_args =
|
||||
List.fold_right
|
||||
(fun (arg : (lcalc, m) gexpr) (collected_closures, new_args) ->
|
||||
let m_arg = Mark.get arg in
|
||||
match Mark.remove arg with
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, arg = Bindlib.unmbind binder in
|
||||
let new_collected_closures, new_arg =
|
||||
(hoist_closures_expr name_context) arg
|
||||
in
|
||||
let new_arg =
|
||||
Expr.make_abs vars new_arg tys (Expr.mark_pos m_arg)
|
||||
in
|
||||
new_collected_closures @ collected_closures, new_arg :: new_args
|
||||
| _ ->
|
||||
let new_collected_closures, new_arg =
|
||||
hoist_closures_expr name_context arg
|
||||
in
|
||||
new_collected_closures @ collected_closures, new_arg :: new_args)
|
||||
args ([], [])
|
||||
in
|
||||
collected_closures, Expr.eapp (Expr.box f) new_args (Mark.get e)
|
||||
| EAbs { tys; _ } ->
|
||||
(* this is the closure we want to hoist*)
|
||||
let closure_var = Var.make ("closure_" ^ name_context) in
|
||||
( [
|
||||
Bindlib.box_apply
|
||||
(fun e ->
|
||||
{
|
||||
name = closure_var;
|
||||
ty = TArrow (tys, (TAny, Expr.mark_pos m)), Expr.mark_pos m;
|
||||
closure = e, m;
|
||||
})
|
||||
(fst (Expr.box e));
|
||||
],
|
||||
Expr.make_var closure_var m )
|
||||
| EApp _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _
|
||||
| EArray _ | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _
|
||||
| EVar _ ->
|
||||
Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_closures_expr name_context) e
|
||||
| _ -> .
|
||||
[@@warning "-32"]
|
||||
|
||||
(* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the
|
||||
type *)
|
||||
let hoist_closures_scope_let name_context scope_body_expr =
|
||||
Scope.fold_right_lets
|
||||
~f:(fun scope_let var_next (hoisted_closures, next_scope_lets) ->
|
||||
let new_hoisted_closures, new_scope_let_expr =
|
||||
(hoist_closures_expr (Bindlib.name_of var_next))
|
||||
scope_let.scope_let_expr
|
||||
in
|
||||
( new_hoisted_closures @ hoisted_closures,
|
||||
Bindlib.box_apply2
|
||||
(fun scope_let_next scope_let_expr ->
|
||||
ScopeLet { scope_let with scope_let_next; scope_let_expr })
|
||||
(Bindlib.bind_var var_next next_scope_lets)
|
||||
(Expr.Box.lift new_scope_let_expr) ))
|
||||
~init:(fun res ->
|
||||
let hoisted_closures, new_scope_let_expr =
|
||||
(hoist_closures_expr name_context) res
|
||||
in
|
||||
(* INVARIANT here: the result expr of a scope is simply a struct
|
||||
containing all output variables so nothing should be converted here, so
|
||||
no need to take into account free variables. *)
|
||||
( hoisted_closures,
|
||||
Bindlib.box_apply
|
||||
(fun res -> Result res)
|
||||
(Expr.Box.lift new_scope_let_expr) ))
|
||||
scope_body_expr
|
||||
|
||||
let hoist_closures_program (p : 'm program) : 'm program Bindlib.box =
|
||||
let hoisted_closures, new_code_items =
|
||||
Scope.fold_map
|
||||
~f:(fun hoisted_closures _var code_item ->
|
||||
match code_item with
|
||||
| ScopeDef (name, body) ->
|
||||
let scope_input_var, scope_body_expr =
|
||||
Bindlib.unbind body.scope_body_expr
|
||||
in
|
||||
let new_hoisted_closures, new_scope_lets =
|
||||
hoist_closures_scope_let
|
||||
(fst (ScopeName.get_info name))
|
||||
scope_body_expr
|
||||
in
|
||||
let new_scope_body_expr =
|
||||
Bindlib.bind_var scope_input_var new_scope_lets
|
||||
in
|
||||
( new_hoisted_closures @ hoisted_closures,
|
||||
Bindlib.box_apply
|
||||
(fun scope_body_expr ->
|
||||
ScopeDef (name, { body with scope_body_expr }))
|
||||
new_scope_body_expr )
|
||||
| Topdef (name, ty, expr) ->
|
||||
let new_hoisted_closures, new_expr =
|
||||
hoist_closures_expr (Mark.remove (TopdefName.get_info name)) expr
|
||||
in
|
||||
( new_hoisted_closures @ hoisted_closures,
|
||||
Bindlib.box_apply
|
||||
(fun e -> Topdef (name, ty, e))
|
||||
(Expr.Box.lift new_expr) ))
|
||||
~varf:(fun v -> v)
|
||||
[] p.code_items
|
||||
in
|
||||
Bindlib.box_apply
|
||||
(fun hoisted_closures ->
|
||||
let new_code_items =
|
||||
List.fold_left
|
||||
(fun (new_code_items : _ gexpr code_item_list Bindlib.box) hc ->
|
||||
let next = Bindlib.bind_var hc.name new_code_items in
|
||||
Bindlib.box_apply
|
||||
(fun next ->
|
||||
Cons
|
||||
( Topdef
|
||||
( TopdefName.fresh
|
||||
(Bindlib.name_of hc.name, Expr.pos hc.closure),
|
||||
hc.ty,
|
||||
hc.closure ),
|
||||
next ))
|
||||
next)
|
||||
new_code_items hoisted_closures
|
||||
in
|
||||
{ p with code_items = Bindlib.unbox new_code_items })
|
||||
(Bindlib.box_list hoisted_closures)
|
||||
|
||||
(** { 1 Closure conversion }*)
|
||||
|
||||
let closure_conversion (p : 'm program) : 'm program Bindlib.box =
|
||||
let new_p = transform_closures_program p in
|
||||
hoist_closures_program (Bindlib.unbox new_p)
|
||||
|
@ -855,16 +855,15 @@ let code_item ?(debug = false) decl_ctx fmt c =
|
||||
match c with
|
||||
| ScopeDef (n, b) -> scope ~debug decl_ctx fmt (n, b)
|
||||
| Topdef (n, ty, e) ->
|
||||
Format.fprintf fmt "@[%a %a %a %a %a %a @]" keyword "let topval"
|
||||
TopdefName.format_t n op_style ":" (typ decl_ctx) ty op_style "="
|
||||
(expr ~debug ()) e
|
||||
Format.fprintf fmt "@[<v 2>@[<hov 2>%a@ %a@ %a@ %a@ %a@]@ %a@]" keyword
|
||||
"let topval" TopdefName.format_t n op_style ":" (typ decl_ctx) ty op_style
|
||||
"=" (expr ~debug ()) e
|
||||
|
||||
let rec code_item_list ?(debug = false) decl_ctx fmt c =
|
||||
match c with
|
||||
| Nil -> ()
|
||||
| Cons (c, b) ->
|
||||
let _x, cl = Bindlib.unbind b in
|
||||
|
||||
Format.fprintf fmt "%a @.%a"
|
||||
(code_item ~debug decl_ctx)
|
||||
c
|
||||
|
@ -951,7 +951,7 @@ let rec scopes ~leave_unresolved ctx env = function
|
||||
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in
|
||||
( Env.add var uf env,
|
||||
Bindlib.box_apply
|
||||
(fun e -> A.Topdef (name, typ, e))
|
||||
(fun e -> A.Topdef (name, Expr.ty e', e))
|
||||
(Expr.Box.lift e') )
|
||||
in
|
||||
let next', env = scopes ~leave_unresolved ctx env next in
|
||||
|
@ -12,21 +12,28 @@ scope S:
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala Lcalc -s S --avoid_exceptions -O --closure_conversion
|
||||
$ catala Lcalc --avoid_exceptions -O --closure_conversion
|
||||
type eoption = | ENone of unit | ESome of any
|
||||
|
||||
type S = { z: eoption integer; }
|
||||
|
||||
type S_in = { x_in: eoption bool; }
|
||||
|
||||
let topval closure_f : (closure_env, integer) → eoption integer =
|
||||
λ (env: closure_env) (y: integer) →
|
||||
ESome
|
||||
match
|
||||
(match (from_closure_env env).0 with
|
||||
| ENone _ → ENone _
|
||||
| ESome x → if x then ESome y else ESome - y)
|
||||
with
|
||||
| ENone _ → raise NoValueProvided
|
||||
| ESome f → f
|
||||
let scope S (S_in: S_in {x_in: eoption bool}): S {z: eoption integer} =
|
||||
let get x : eoption bool = S_in.x_in in
|
||||
let set f :
|
||||
eoption ((closure_env, integer) → eoption integer * closure_env) =
|
||||
ESome
|
||||
(λ (env: closure_env) (y: integer) →
|
||||
ESome
|
||||
match
|
||||
(match (from_closure_env env).0 with
|
||||
| ENone _ → ENone _
|
||||
| ESome x → if x then ESome y else ESome - y)
|
||||
with
|
||||
| ENone _ → raise NoValueProvided
|
||||
| ESome f → f, to_closure_env (x))
|
||||
ESome (closure_f, to_closure_env (x))
|
||||
in
|
||||
let set z : eoption integer =
|
||||
ESome
|
||||
@ -44,4 +51,5 @@ let scope S (S_in: S_in {x_in: eoption bool}): S {z: eoption integer} =
|
||||
| ESome z → z
|
||||
in
|
||||
return { S z = z; }
|
||||
|
||||
```
|
||||
|
@ -10,7 +10,25 @@ scope S:
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala Lcalc -s S --avoid_exceptions -O --closure_conversion
|
||||
$ catala Lcalc --avoid_exceptions -O --closure_conversion
|
||||
type eoption = | ENone of unit | ESome of any
|
||||
|
||||
type S = {
|
||||
f: eoption ((closure_env, integer) → eoption integer * closure_env);
|
||||
}
|
||||
|
||||
type S_in = { x_in: eoption bool; }
|
||||
|
||||
let topval closure_f : (closure_env, integer) → eoption integer =
|
||||
λ (env: closure_env) (y: integer) →
|
||||
ESome
|
||||
match
|
||||
(match (from_closure_env env).0 with
|
||||
| ENone _ → ENone _
|
||||
| ESome x → if x then ESome y else ESome - y)
|
||||
with
|
||||
| ENone _ → raise NoValueProvided
|
||||
| ESome f → f
|
||||
let scope S
|
||||
(S_in: S_in {x_in: eoption bool})
|
||||
: S {f: eoption ((closure_env, integer) → eoption integer * closure_env)}
|
||||
@ -18,16 +36,8 @@ let scope S
|
||||
let get x : eoption bool = S_in.x_in in
|
||||
let set f :
|
||||
eoption ((closure_env, integer) → eoption integer * closure_env) =
|
||||
ESome
|
||||
(λ (env: closure_env) (y: integer) →
|
||||
ESome
|
||||
match
|
||||
(match (from_closure_env env).0 with
|
||||
| ENone _ → ENone _
|
||||
| ESome x → if x then ESome y else ESome - y)
|
||||
with
|
||||
| ENone _ → raise NoValueProvided
|
||||
| ESome f → f, to_closure_env (x))
|
||||
ESome (closure_f, to_closure_env (x))
|
||||
in
|
||||
return { S f = f; }
|
||||
|
||||
```
|
||||
|
Loading…
Reference in New Issue
Block a user