Generalise the definition of lists of nested binders

This commit is contained in:
Louis Gesbert 2024-02-09 16:48:02 +01:00
parent c124943a6e
commit e308ff8d02
31 changed files with 724 additions and 869 deletions

View File

@ -642,14 +642,14 @@ let translate_rule
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun next merged_expr -> (fun next merged_expr ->
ScopeLet Cons
{ ( {
scope_let_next = next;
scope_let_typ = tau; scope_let_typ = tau;
scope_let_expr = merged_expr; scope_let_expr = merged_expr;
scope_let_kind = ScopeVarDefinition; scope_let_kind = ScopeVarDefinition;
scope_let_pos = Mark.get a; scope_let_pos = Mark.get a;
}) },
next ))
(Bindlib.bind_var a_var next) (Bindlib.bind_var a_var next)
(Expr.Box.lift merged_expr)), (Expr.Box.lift merged_expr)),
{ {
@ -691,14 +691,14 @@ let translate_rule
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun next thunked_or_nonempty_new_e -> (fun next thunked_or_nonempty_new_e ->
ScopeLet Cons
{ ( {
scope_let_next = next;
scope_let_pos = Mark.get a_name; scope_let_pos = Mark.get a_name;
scope_let_typ = input_var_typ (Mark.remove tau) a_io; scope_let_typ = input_var_typ (Mark.remove tau) a_io;
scope_let_expr = thunked_or_nonempty_new_e; scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
}) },
next ))
(Bindlib.bind_var a_var next) (Bindlib.bind_var a_var next)
(Expr.Box.lift thunked_or_nonempty_new_e)), (Expr.Box.lift thunked_or_nonempty_new_e)),
{ {
@ -836,14 +836,14 @@ let translate_rule
let call_scope_let next = let call_scope_let next =
Bindlib.box_apply2 Bindlib.box_apply2
(fun next call_expr -> (fun next call_expr ->
ScopeLet Cons
{ ( {
scope_let_next = next;
scope_let_pos = pos_sigma; scope_let_pos = pos_sigma;
scope_let_kind = CallingSubScope; scope_let_kind = CallingSubScope;
scope_let_typ = result_tuple_typ; scope_let_typ = result_tuple_typ;
scope_let_expr = call_expr; scope_let_expr = call_expr;
}) },
next ))
(Bindlib.bind_var result_tuple_var next) (Bindlib.bind_var result_tuple_var next)
(Expr.Box.lift call_expr) (Expr.Box.lift call_expr)
in in
@ -856,9 +856,8 @@ let translate_rule
in in
Bindlib.box_apply2 Bindlib.box_apply2
(fun next r -> (fun next r ->
ScopeLet Cons
{ ( {
scope_let_next = next;
scope_let_pos = pos_sigma; scope_let_pos = pos_sigma;
scope_let_typ = var_ctx.scope_var_typ, pos_sigma; scope_let_typ = var_ctx.scope_var_typ, pos_sigma;
scope_let_kind = DestructuringSubScopeResults; scope_let_kind = DestructuringSubScopeResults;
@ -866,7 +865,8 @@ let translate_rule
( EStructAccess ( EStructAccess
{ name = called_scope_return_struct; e = r; field }, { name = called_scope_return_struct; e = r; field },
mark_tany m pos_sigma ); mark_tany m pos_sigma );
}) },
next ))
(Bindlib.bind_var v next) (Bindlib.bind_var v next)
(Expr.Box.lift (Expr.Box.lift
(Expr.make_var result_tuple_var (mark_tany m pos_sigma)))) (Expr.make_var result_tuple_var (mark_tany m pos_sigma))))
@ -892,9 +892,8 @@ let translate_rule
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
(fun next new_e -> (fun next new_e ->
ScopeLet Cons
{ ( {
scope_let_next = next;
scope_let_pos; scope_let_pos;
scope_let_typ; scope_let_typ;
scope_let_expr = scope_let_expr =
@ -902,7 +901,8 @@ let translate_rule
(Expr.map_ty (fun _ -> scope_let_typ) (Mark.get e)) (Expr.map_ty (fun _ -> scope_let_typ) (Mark.get e))
(EAssert new_e); (EAssert new_e);
scope_let_kind = Assertion; scope_let_kind = Assertion;
}) },
next ))
(Bindlib.bind_var (Var.make "_") next) (Bindlib.bind_var (Var.make "_") next)
(Expr.Box.lift new_e)), (Expr.Box.lift new_e)),
ctx ) ctx )
@ -944,7 +944,7 @@ let translate_rules
in in
( scope_lets ( scope_lets
(Bindlib.box_apply (Bindlib.box_apply
(fun return_exp -> Result return_exp) (fun return_exp -> Last return_exp)
(Expr.Box.lift return_exp)), (Expr.Box.lift return_exp)),
new_ctx ) new_ctx )
@ -1042,10 +1042,9 @@ let translate_scope_decl
in in
Bindlib.box_apply2 Bindlib.box_apply2
(fun next r -> (fun next r ->
ScopeLet Cons
{ ( {
scope_let_kind = DestructuringInputStruct; scope_let_kind = DestructuringInputStruct;
scope_let_next = next;
scope_let_pos = pos_sigma; scope_let_pos = pos_sigma;
scope_let_typ = scope_let_typ =
input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io; input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
@ -1053,7 +1052,8 @@ let translate_scope_decl
( EStructAccess ( EStructAccess
{ name = scope_input_struct_name; e = r; field }, { name = scope_input_struct_name; e = r; field },
mark_tany scope_mark pos_sigma ); mark_tany scope_mark pos_sigma );
}) },
next ))
(Bindlib.bind_var v next) (Bindlib.bind_var v next)
(Expr.Box.lift (Expr.Box.lift
(Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma)))) (Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma))))
@ -1182,7 +1182,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
ending with the top-level scope. The decl_ctx is filled in left-to-right ending with the top-level scope. The decl_ctx is filled in left-to-right
order, then the chained scopes aggregated from the right. *) order, then the chained scopes aggregated from the right. *)
let rec translate_defs = function let rec translate_defs = function
| [] -> Bindlib.box Nil | [] -> Bindlib.box (Last ())
| def :: next -> | def :: next ->
let dvar, def = let dvar, def =
match def with match def with

View File

@ -22,38 +22,24 @@ type invariant_status = Fail | Pass | Ignore
type invariant_expr = decl_ctx -> typed expr -> invariant_status type invariant_expr = decl_ctx -> typed expr -> invariant_status
let check_invariant (inv : string * invariant_expr) (p : typed program) : bool = let check_invariant (inv : string * invariant_expr) (p : typed program) : bool =
(* TODO: add a Program.fold_left_map_exprs to get rid of the mutable
reference *)
let result = ref true in
let name, inv = inv in let name, inv = inv in
let total = ref 0 in let result, total, ok =
let ok = ref 0 in Program.fold_exprs p ~init:(true, 0, 0) ~f:(fun acc e _ty ->
let p' =
Program.map_exprs p ~varf:Fun.id ~f:(fun e ->
(* let currente = e in *) (* let currente = e in *)
let rec f e = let rec f e (result, total, ok) =
let r = let result, total, ok = Expr.shallow_fold f e (result, total, ok) in
match inv p.decl_ctx e with match inv p.decl_ctx e with
| Ignore -> true | Ignore -> result, total, ok
| Fail -> | Fail ->
Message.raise_spanned_error (Expr.pos e) Message.raise_spanned_error (Expr.pos e)
"@[<v 2>Invariant @{<magenta>%s@} failed.@,%a@]" name "@[<v 2>Invariant @{<magenta>%s@} failed.@,%a@]" name
(Print.expr ()) e (Print.expr ()) e
| Pass -> | Pass -> result, total + 1, ok + 1
incr ok;
incr total;
true
in in
Expr.map_gather e ~acc:r ~join:( && ) ~f f e acc)
in in
Message.emit_debug "Invariant %s checked.@ result: [%d/%d]" name ok total;
let res, e' = f e in result
result := res && !result;
e')
in
assert (Bindlib.free_vars p' = Bindlib.empty_ctxt);
Message.emit_debug "Invariant %s checked.@ result: [%d/%d]" name !ok !total;
!result
(* Structural invariant: no default can have as type A -> B *) (* Structural invariant: no default can have as type A -> B *)
let invariant_default_no_arrow () : string * invariant_expr = let invariant_default_no_arrow () : string * invariant_expr =

View File

@ -586,14 +586,12 @@ module Commands = struct
let scope_uid = get_scope_uid prg.decl_ctx scope in let scope_uid = get_scope_uid prg.decl_ctx scope in
Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt
( scope_uid, ( scope_uid,
Option.get BoundList.find
(Scope.fold_left ~init:None ~f:(function
~f:(fun acc def _ ->
match def with
| ScopeDef (name, body) when ScopeName.equal name scope_uid -> | ScopeDef (name, body) when ScopeName.equal name scope_uid ->
Some body Some body
| _ -> acc) | _ -> None)
prg.code_items) ); prg.code_items );
Format.pp_print_newline fmt () Format.pp_print_newline fmt ()
| None -> | None ->
let scope_uid = get_random_scope_uid prg.decl_ctx in let scope_uid = get_random_scope_uid prg.decl_ctx in

View File

@ -286,10 +286,9 @@ let rec transform_closures_expr :
new_e1 call_expr (Expr.pos e) ) new_e1 call_expr (Expr.pos e) )
| _ -> . | _ -> .
(* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the (* Can't reuse Scope.map because we inspect the bind variables *)
type *)
let transform_closures_scope_let ctx scope_body_expr = let transform_closures_scope_let ctx scope_body_expr =
Scope.fold_right_lets BoundList.fold_right
~f:(fun scope_let var_next acc -> ~f:(fun scope_let var_next acc ->
let _free_vars, new_scope_let_expr = let _free_vars, new_scope_let_expr =
(transform_closures_expr (transform_closures_expr
@ -298,13 +297,13 @@ let transform_closures_scope_let ctx scope_body_expr =
in in
Bindlib.box_apply2 Bindlib.box_apply2
(fun scope_let_next scope_let_expr -> (fun scope_let_next scope_let_expr ->
ScopeLet Cons
{ ( {
scope_let with scope_let with
scope_let_next;
scope_let_expr; scope_let_expr;
scope_let_typ = Mark.copy scope_let.scope_let_typ TAny; scope_let_typ = Mark.copy scope_let.scope_let_typ TAny;
}) },
scope_let_next ))
(Bindlib.bind_var var_next acc) (Bindlib.bind_var var_next acc)
(Expr.Box.lift new_scope_let_expr)) (Expr.Box.lift new_scope_let_expr))
~init:(fun res -> ~init:(fun res ->
@ -312,14 +311,12 @@ let transform_closures_scope_let ctx scope_body_expr =
(* INVARIANT here: the result expr of a scope is simply a struct (* INVARIANT here: the result expr of a scope is simply a struct
containing all output variables so nothing should be converted here, so containing all output variables so nothing should be converted here, so
no need to take into account free variables. *) no need to take into account free variables. *)
Bindlib.box_apply Bindlib.box_apply (fun e -> Last e) (Expr.Box.lift new_scope_let_expr))
(fun res -> Result res)
(Expr.Box.lift new_scope_let_expr))
scope_body_expr scope_body_expr
let transform_closures_program (p : 'm program) : 'm program Bindlib.box = let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
let _, new_code_items = let (), new_code_items =
Scope.fold_map BoundList.fold_map
~f:(fun toplevel_vars var code_item -> ~f:(fun toplevel_vars var code_item ->
match code_item with match code_item with
| ScopeDef (name, body) -> | ScopeDef (name, body) ->
@ -346,6 +343,7 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
pos ) pos )
in in
( Var.Map.add var ty toplevel_vars, ( Var.Map.add var ty toplevel_vars,
var,
Bindlib.box_apply Bindlib.box_apply
(fun scope_body_expr -> (fun scope_body_expr ->
ScopeDef (name, { body with scope_body_expr })) ScopeDef (name, { body with scope_body_expr }))
@ -361,6 +359,7 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
let _free_vars, new_expr = transform_closures_expr ctx expr in let _free_vars, new_expr = transform_closures_expr ctx expr in
let new_binder = Expr.bind v new_expr in let new_binder = Expr.bind v new_expr in
( Var.Map.add var ty toplevel_vars, ( Var.Map.add var ty toplevel_vars,
var,
Bindlib.box_apply Bindlib.box_apply
(fun e -> Topdef (name, ty, e)) (fun e -> Topdef (name, ty, e))
(Expr.Box.lift (Expr.eabs new_binder tys m)) ) (Expr.Box.lift (Expr.eabs new_binder tys m)) )
@ -373,11 +372,12 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
in in
let _free_vars, new_expr = transform_closures_expr ctx expr in let _free_vars, new_expr = transform_closures_expr ctx expr in
( Var.Map.add var ty toplevel_vars, ( Var.Map.add var ty toplevel_vars,
var,
Bindlib.box_apply Bindlib.box_apply
(fun e -> Topdef (name, (TAny, Mark.get ty), e)) (fun e -> Topdef (name, (TAny, Mark.get ty), e))
(Expr.Box.lift new_expr) )) (Expr.Box.lift new_expr) ))
~varf:(fun v -> v) ~last:(fun _ () -> (), Bindlib.box ())
Var.Map.empty p.code_items ~init:Var.Map.empty p.code_items
in in
(* Now we need to further tweak [decl_ctx] because some of the user-defined (* Now we need to further tweak [decl_ctx] because some of the user-defined
types can have closures in them and these closured might have changed type. types can have closures in them and these closured might have changed type.
@ -550,7 +550,7 @@ let rec hoist_closures_expr :
(* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the (* Here I have to reimplement Scope.map_exprs_in_lets because I'm changing the
type *) type *)
let hoist_closures_scope_let name_context scope_body_expr = let hoist_closures_scope_let name_context scope_body_expr =
Scope.fold_right_lets BoundList.fold_right
~f:(fun scope_let var_next (hoisted_closures, next_scope_lets) -> ~f:(fun scope_let var_next (hoisted_closures, next_scope_lets) ->
let new_hoisted_closures, new_scope_let_expr = let new_hoisted_closures, new_scope_let_expr =
(hoist_closures_expr (Bindlib.name_of var_next)) (hoist_closures_expr (Bindlib.name_of var_next))
@ -559,7 +559,7 @@ let hoist_closures_scope_let name_context scope_body_expr =
( new_hoisted_closures @ hoisted_closures, ( new_hoisted_closures @ hoisted_closures,
Bindlib.box_apply2 Bindlib.box_apply2
(fun scope_let_next scope_let_expr -> (fun scope_let_next scope_let_expr ->
ScopeLet { scope_let with scope_let_next; scope_let_expr }) Cons ({ scope_let with scope_let_expr }, scope_let_next))
(Bindlib.bind_var var_next next_scope_lets) (Bindlib.bind_var var_next next_scope_lets)
(Expr.Box.lift new_scope_let_expr) )) (Expr.Box.lift new_scope_let_expr) ))
~init:(fun res -> ~init:(fun res ->
@ -571,7 +571,7 @@ let hoist_closures_scope_let name_context scope_body_expr =
no need to take into account free variables. *) no need to take into account free variables. *)
( hoisted_closures, ( hoisted_closures,
Bindlib.box_apply Bindlib.box_apply
(fun res -> Result res) (fun res -> Last res)
(Expr.Box.lift new_scope_let_expr) )) (Expr.Box.lift new_scope_let_expr) ))
scope_body_expr scope_body_expr
@ -579,7 +579,7 @@ let rec hoist_closures_code_item_list
(code_items : (lcalc, 'm) gexpr code_item_list) : (code_items : (lcalc, 'm) gexpr code_item_list) :
(lcalc, 'm) gexpr code_item_list Bindlib.box = (lcalc, 'm) gexpr code_item_list Bindlib.box =
match code_items with match code_items with
| Nil -> Bindlib.box Nil | Last () -> Bindlib.box (Last ())
| Cons (code_item, next_code_items) -> | Cons (code_item, next_code_items) ->
let code_item_var, next_code_items = Bindlib.unbind next_code_items in let code_item_var, next_code_items = Bindlib.unbind next_code_items in
let hoisted_closures, new_code_item = let hoisted_closures, new_code_item =

View File

@ -92,59 +92,5 @@ and translate_expr (e : 'm D.expr) : 'm A.expr boxed =
Expr.map ~f:translate_expr ~typ:translate_typ e Expr.map ~f:translate_expr ~typ:translate_typ e
| _ -> . | _ -> .
let translate_scope_body_expr (scope_body_expr : 'expr1 scope_body_expr) :
'expr2 scope_body_expr Bindlib.box =
Scope.fold_right_lets
~f:(fun scope_let var_next acc ->
Bindlib.box_apply2
(fun scope_let_next scope_let_expr ->
ScopeLet
{
scope_let with
scope_let_next;
scope_let_expr;
scope_let_typ = translate_typ scope_let.scope_let_typ;
})
(Bindlib.bind_var (Var.translate var_next) acc)
(Expr.Box.lift (translate_expr scope_let.scope_let_expr)))
~init:(fun res ->
Bindlib.box_apply
(fun res -> Result res)
(Expr.Box.lift (translate_expr res)))
scope_body_expr
let translate_code_items scopes =
let f = function
| ScopeDef (name, body) ->
let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in
let new_body_expr = translate_scope_body_expr scope_lets in
let new_body_expr =
Bindlib.bind_var (Var.translate scope_input_var) new_body_expr
in
Bindlib.box_apply
(fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr }))
new_body_expr
| Topdef (name, typ, expr) ->
Bindlib.box_apply
(fun e -> Topdef (name, typ, e))
(Expr.Box.lift (translate_expr expr))
in
Scope.map ~f ~varf:Var.translate scopes
let translate_program (prg : 'm D.program) : 'm A.program = let translate_program (prg : 'm D.program) : 'm A.program =
let code_items = Bindlib.unbox (translate_code_items prg.code_items) in Program.map_exprs prg ~typ:translate_typ ~varf:Var.translate ~f:translate_expr
let ctx_enums =
EnumName.Map.map
(EnumConstructor.Map.map translate_typ)
prg.decl_ctx.ctx_enums
in
let ctx_structs =
StructName.Map.map
(StructField.Map.map translate_typ)
prg.decl_ctx.ctx_structs
in
{
prg with
code_items;
decl_ctx = { prg.decl_ctx with ctx_enums; ctx_structs };
}

View File

@ -123,60 +123,5 @@ and translate_expr (e : 'm D.expr) : 'm A.expr boxed =
Expr.map ~f:translate_expr ~typ:translate_typ e Expr.map ~f:translate_expr ~typ:translate_typ e
| _ -> . | _ -> .
let translate_scope_body_expr
(scope_body_expr : (dcalc, 'm) gexpr scope_body_expr) :
(lcalc, 'm) gexpr scope_body_expr Bindlib.box =
Scope.fold_right_lets
~f:(fun scope_let var_next acc ->
Bindlib.box_apply2
(fun scope_let_next scope_let_expr ->
ScopeLet
{
scope_let with
scope_let_next;
scope_let_expr;
scope_let_typ = translate_typ scope_let.scope_let_typ;
})
(Bindlib.bind_var (Var.translate var_next) acc)
(Expr.Box.lift (translate_expr scope_let.scope_let_expr)))
~init:(fun res ->
Bindlib.box_apply
(fun res -> Result res)
(Expr.Box.lift (translate_expr res)))
scope_body_expr
let translate_code_items scopes =
let f = function
| ScopeDef (name, body) ->
let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in
let new_body_expr = translate_scope_body_expr scope_lets in
let new_body_expr =
Bindlib.bind_var (Var.translate scope_input_var) new_body_expr
in
Bindlib.box_apply
(fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr }))
new_body_expr
| Topdef (name, typ, expr) ->
Bindlib.box_apply
(fun e -> Topdef (name, typ, e))
(Expr.Box.lift (translate_expr expr))
in
Scope.map ~f ~varf:Var.translate scopes
let translate_program (prg : 'm D.program) : 'm A.program = let translate_program (prg : 'm D.program) : 'm A.program =
let code_items = Bindlib.unbox (translate_code_items prg.code_items) in Program.map_exprs prg ~typ:translate_typ ~varf:Var.translate ~f:translate_expr
let ctx_enums =
EnumName.Map.map
(EnumConstructor.Map.map translate_typ)
prg.decl_ctx.ctx_enums
in
let ctx_structs =
StructName.Map.map
(StructField.Map.map translate_typ)
prg.decl_ctx.ctx_structs
in
{
prg with
code_items;
decl_ctx = { prg.decl_ctx with ctx_enums; ctx_structs };
}

View File

@ -46,6 +46,9 @@ type monomorphized_instances = {
arrays : array_instance Type.Map.t; arrays : array_instance Type.Map.t;
} }
let empty_instances =
{ options = Type.Map.empty; tuples = Type.Map.empty; arrays = Type.Map.empty }
let collect_monomorphized_instances (prg : typed program) : let collect_monomorphized_instances (prg : typed program) :
monomorphized_instances = monomorphized_instances =
let option_instances_counter = ref 0 in let option_instances_counter = ref 0 in
@ -157,23 +160,8 @@ let collect_monomorphized_instances (prg : typed program) :
Expr.shallow_fold collect_expr e (collect_typ acc (Expr.ty e)) Expr.shallow_fold collect_expr e (collect_typ acc (Expr.ty e))
in in
let acc = let acc =
Scope.fold_left Scope.fold_exprs prg.code_items ~init:empty_instances ~f:(fun acc e typ ->
~init: collect_typ (collect_expr e acc) typ)
{
options = Type.Map.empty;
tuples = Type.Map.empty;
arrays = Type.Map.empty;
}
~f:(fun acc item _ ->
match item with
| Topdef (_, typ, e) -> collect_typ (collect_expr e acc) typ
| ScopeDef (_, body) ->
let _, body = Bindlib.unbind body.scope_body_expr in
Scope.fold_left_lets ~init:acc
~f:(fun acc { scope_let_typ; scope_let_expr; _ } _ ->
collect_typ (collect_expr scope_let_expr acc) scope_let_typ)
body)
prg.code_items
in in
EnumName.Map.fold EnumName.Map.fold
(fun _ constructors acc -> (fun _ constructors acc ->
@ -301,46 +289,24 @@ let rec monomorphize_expr
let program (prg : typed program) : let program (prg : typed program) :
typed program * Scopelang.Dependency.TVertex.t list = typed program * Scopelang.Dependency.TVertex.t list =
let monomorphized_instances = collect_monomorphized_instances prg in let monomorphized_instances = collect_monomorphized_instances prg in
let decl_ctx = prg.decl_ctx in
(* First we remove the polymorphic option type *) (* First we remove the polymorphic option type *)
let prg = let ctx_enums = EnumName.Map.remove Expr.option_enum decl_ctx.ctx_enums in
{ let ctx_structs = decl_ctx.ctx_structs in
prg with
decl_ctx =
{
prg.decl_ctx with
ctx_enums =
EnumName.Map.remove Expr.option_enum prg.decl_ctx.ctx_enums;
};
}
in
(* Then we replace all hardcoded types and expressions with the monomorphized (* Then we replace all hardcoded types and expressions with the monomorphized
instances *) instances *)
let prg = let ctx_enums =
{
prg with
decl_ctx =
{
prg.decl_ctx with
ctx_enums =
EnumName.Map.map EnumName.Map.map
(EnumConstructor.Map.map (EnumConstructor.Map.map (monomorphize_typ monomorphized_instances))
(monomorphize_typ monomorphized_instances)) ctx_enums
prg.decl_ctx.ctx_enums; in
ctx_structs = let ctx_structs =
StructName.Map.map StructName.Map.map
(StructField.Map.map (monomorphize_typ monomorphized_instances)) (StructField.Map.map (monomorphize_typ monomorphized_instances))
prg.decl_ctx.ctx_structs; ctx_structs
};
}
in in
(* Then we augment the [decl_ctx] with the monomorphized instances *) (* Then we augment the [decl_ctx] with the monomorphized instances *)
let prg = let ctx_enums =
{
prg with
decl_ctx =
{
prg.decl_ctx with
ctx_enums =
Type.Map.fold Type.Map.fold
(fun _ (option_instance : option_instance) (ctx_enums : enum_ctx) -> (fun _ (option_instance : option_instance) (ctx_enums : enum_ctx) ->
EnumName.Map.add option_instance.name EnumName.Map.add option_instance.name
@ -350,24 +316,22 @@ let program (prg : typed program) :
(monomorphize_typ monomorphized_instances (monomorphize_typ monomorphized_instances
(option_instance.some_typ, Pos.no_pos)))) (option_instance.some_typ, Pos.no_pos))))
ctx_enums) ctx_enums)
monomorphized_instances.options prg.decl_ctx.ctx_enums; monomorphized_instances.options ctx_enums
ctx_structs = in
let ctx_structs =
Type.Map.fold Type.Map.fold
(fun _ (tuple_instance : tuple_instance) (fun _ (tuple_instance : tuple_instance) (ctx_structs : struct_ctx) ->
(ctx_structs : struct_ctx) ->
StructName.Map.add tuple_instance.name StructName.Map.add tuple_instance.name
(List.fold_left (List.fold_left
(fun acc (field, typ) -> (fun acc (field, typ) ->
StructField.Map.add field StructField.Map.add field
(monomorphize_typ monomorphized_instances (monomorphize_typ monomorphized_instances (typ, Pos.no_pos))
(typ, Pos.no_pos))
acc) acc)
StructField.Map.empty tuple_instance.fields) StructField.Map.empty tuple_instance.fields)
ctx_structs) ctx_structs)
monomorphized_instances.tuples monomorphized_instances.tuples
(Type.Map.fold (Type.Map.fold
(fun _ (array_instance : array_instance) (fun _ (array_instance : array_instance) (ctx_structs : struct_ctx) ->
(ctx_structs : struct_ctx) ->
StructName.Map.add array_instance.name StructName.Map.add array_instance.name
(StructField.Map.add array_instance.content_field (StructField.Map.add array_instance.content_field
( TArray ( TArray
@ -377,28 +341,15 @@ let program (prg : typed program) :
(StructField.Map.singleton array_instance.len_field (StructField.Map.singleton array_instance.len_field
(TLit TInt, Pos.no_pos))) (TLit TInt, Pos.no_pos)))
ctx_structs) ctx_structs)
monomorphized_instances.arrays prg.decl_ctx.ctx_structs); monomorphized_instances.arrays ctx_structs)
};
}
in in
let decl_ctx = { decl_ctx with ctx_structs; ctx_enums } in
let code_items = let code_items =
Bindlib.unbox Bindlib.unbox
@@ Scope.map @@ Scope.map_exprs prg.code_items
~f:(fun code_item -> ~typ:(monomorphize_typ monomorphized_instances)
match code_item with ~varf:Fun.id
| Topdef (name, typ, e) -> Bindlib.box (Topdef (name, typ, e))
| ScopeDef (name, body) ->
let s_var, scope_body = Bindlib.unbind body.scope_body_expr in
Bindlib.box_apply
(fun scope_body_expr ->
ScopeDef (name, { body with scope_body_expr }))
(Bindlib.bind_var s_var
(Scope.map_exprs_in_lets ~varf:Fun.id
~transform_types:(monomorphize_typ monomorphized_instances)
~f:(monomorphize_expr monomorphized_instances) ~f:(monomorphize_expr monomorphized_instances)
scope_body)))
~varf:Fun.id prg.code_items
in in
( { prg with code_items }, ( { prg with decl_ctx; code_items },
Scopelang.Dependency.check_type_cycles prg.decl_ctx.ctx_structs Scopelang.Dependency.check_type_cycles ctx_structs ctx_enums )
prg.decl_ctx.ctx_enums )

View File

@ -448,8 +448,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
| ERaise exc -> | ERaise exc ->
Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e) Format.fprintf fmt "raise@ %a" format_exception (exc, Expr.pos e)
| ECatch { body; exn; handler } -> | ECatch { body; exn; handler } ->
Format.fprintf fmt Format.fprintf fmt "@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]"
"@,@[<hv>@[<hov 2>try@ %a@]@ with@]@ @[<hov 2>%a@ ->@ %a@]"
format_with_parens body format_exception format_with_parens body format_exception
(exn, Expr.pos e) (exn, Expr.pos e)
format_with_parens handler format_with_parens handler
@ -569,48 +568,57 @@ let rename_vars e =
(rename_vars ~exclude:ocaml_keywords ~reset_context_for_closed_terms:true (rename_vars ~exclude:ocaml_keywords ~reset_context_for_closed_terms:true
~skip_constant_binders:true ~constant_binder_name:(Some "_") e)) ~skip_constant_binders:true ~constant_binder_name:(Some "_") e))
let format_expr ctx fmt e = format_expr ctx fmt (rename_vars e) let format_expr ctx fmt e =
Format.pp_open_vbox fmt 0;
format_expr ctx fmt (rename_vars e);
Format.pp_close_box fmt ()
let rec format_scope_body_expr let format_scope_body_expr
(ctx : decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(scope_lets : 'm Ast.expr scope_body_expr) : unit = (scope_lets : 'm Ast.expr scope_body_expr) : unit =
match scope_lets with Format.pp_open_vbox fmt 0;
| Result e -> format_expr ctx fmt e let last_e =
| ScopeLet scope_let -> BoundList.iter
let scope_let_var, scope_let_next = ~f:(fun scope_let_var scope_let ->
Bindlib.unbind scope_let.scope_let_next Format.fprintf fmt "@[<hv>@[<hov 2>let %a: %a =@ %a@ @]in@]@,"
format_var scope_let_var format_typ scope_let.scope_let_typ
(format_expr ctx) scope_let.scope_let_expr)
scope_lets
in in
Format.fprintf fmt "@[<hov 2>let %a: %a = %a in@]@\n%a" format_var format_expr ctx fmt last_e;
scope_let_var format_typ scope_let.scope_let_typ (format_expr ctx) Format.pp_close_box fmt ()
scope_let.scope_let_expr
(format_scope_body_expr ctx)
scope_let_next
let format_code_items let format_code_items
(ctx : decl_ctx) (ctx : decl_ctx)
(fmt : Format.formatter) (fmt : Format.formatter)
(code_items : 'm Ast.expr code_item_list) : (code_items : 'm Ast.expr code_item_list) :
('m Ast.expr Var.t * 'm Ast.expr code_item) String.Map.t = ('m Ast.expr Var.t * 'm Ast.expr code_item) String.Map.t =
Scope.fold_left Format.pp_open_vbox fmt 0;
let var_bindings, () =
BoundList.fold_left
~f:(fun bnd item var -> ~f:(fun bnd item var ->
match item with match item with
| Topdef (name, typ, e) -> | Topdef (name, typ, e) ->
Format.fprintf fmt "@\n@\n@[<hov 2>let %a : %a =@\n%a@]" format_var var Format.fprintf fmt "@,@[<v 2>@[<hov 2>let %a : %a =@]@ %a@]@,"
format_typ typ (format_expr ctx) e; format_var var format_typ typ (format_expr ctx) e;
String.Map.add (TopdefName.to_string name) (var, item) bnd String.Map.add (TopdefName.to_string name) (var, item) bnd
| ScopeDef (name, body) -> | ScopeDef (name, body) ->
let scope_input_var, scope_body_expr = let scope_input_var, scope_body_expr =
Bindlib.unbind body.scope_body_expr Bindlib.unbind body.scope_body_expr
in in
Format.fprintf fmt "@\n@\n@[<hov 2>let %a (%a: %a.t) : %a.t =@\n%a@]" Format.fprintf fmt
format_var var format_var scope_input_var format_to_module_name "@,@[<hv 2>@[<hov 2>let %a (%a: %a.t) : %a.t =@]@ %a@]@," format_var
var format_var scope_input_var format_to_module_name
(`Sname body.scope_body_input_struct) format_to_module_name (`Sname body.scope_body_input_struct) format_to_module_name
(`Sname body.scope_body_output_struct) (`Sname body.scope_body_output_struct)
(format_scope_body_expr ctx) (format_scope_body_expr ctx)
scope_body_expr; scope_body_expr;
String.Map.add (ScopeName.to_string name) (var, item) bnd) String.Map.add (ScopeName.to_string name) (var, item) bnd)
~init:String.Map.empty code_items ~init:String.Map.empty code_items
in
Format.pp_close_box fmt ();
var_bindings
let format_scope_exec let format_scope_exec
(ctx : decl_ctx) (ctx : decl_ctx)

View File

@ -574,7 +574,7 @@ let rec translate_scope_body_expr
(func_dict : ('m L.expr, A.FuncName.t) Var.Map.t) (func_dict : ('m L.expr, A.FuncName.t) Var.Map.t)
(scope_expr : 'm L.expr scope_body_expr) : A.block = (scope_expr : 'm L.expr scope_body_expr) : A.block =
match scope_expr with match scope_expr with
| Result e -> | Last e ->
let block, new_e = let block, new_e =
translate_expr translate_expr
{ {
@ -587,8 +587,8 @@ let rec translate_scope_body_expr
e e
in in
block @ [A.SReturn (Mark.remove new_e), Mark.get new_e] block @ [A.SReturn (Mark.remove new_e), Mark.get new_e]
| ScopeLet scope_let -> | Cons (scope_let, next_bnd) ->
let let_var, scope_let_next = Bindlib.unbind scope_let.scope_let_next in let let_var, scope_let_next = Bindlib.unbind next_bnd in
let let_var_id = let let_var_id =
A.VarName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos) A.VarName.fresh (Bindlib.name_of let_var, scope_let.scope_let_pos)
in in
@ -637,8 +637,8 @@ let rec translate_scope_body_expr
let translate_program ~(config : translation_config) (p : 'm L.program) : let translate_program ~(config : translation_config) (p : 'm L.program) :
A.program = A.program =
let _, _, rev_items = let (_, _, rev_items), () =
Scope.fold_left BoundList.fold_left
~f:(fun (func_dict, var_dict, rev_items) code_item var -> ~f:(fun (func_dict, var_dict, rev_items) code_item var ->
match code_item with match code_item with
| ScopeDef (name, body) -> | ScopeDef (name, body) ->

View File

@ -0,0 +1,118 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Definitions
type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list =
| Last of 'last
| Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder
let rec last = function
| Last e -> e
| Cons (_, bnd) ->
let _, next = Bindlib.unbind bnd in
last next
let rec iter ~f = function
| Last l -> l
| Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
f var item;
iter ~f next
let rec find ~f = function
| Last _ -> raise Not_found
| Cons (item, next_bind) -> (
match f item with
| Some r -> r
| None ->
let _, next = Bindlib.unbind next_bind in
find ~f next)
let rec fold_left ~f ~init = function
| Last l -> init, l
| Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
fold_left ~f ~init:(f init item var) next
let rec fold_right ~f ~init = function
| Last l -> init l
| Cons (item, next_bind) ->
let var_next, next = Bindlib.unbind next_bind in
let result_next = fold_right ~f ~init next in
f item var_next result_next
let rec fold_lr ~top ~down ~bottom ~up = function
| Last l -> bottom l top
| Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
let top = down var item top in
let bottom = fold_lr ~down ~up ~top ~bottom next in
up var item bottom
let rec map ~f ~last = function
| Last l -> Bindlib.box_apply (fun l -> Last l) (last l)
| Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
let var, item = f var item in
let next_bind = Bindlib.bind_var var (map ~f ~last next) in
Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
item next_bind
let rec fold_map ~f ~last ~init:ctx = function
| Last l ->
let ret, l = last ctx l in
ret, Bindlib.box_apply (fun l -> Last l) l
| Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
let ctx, var, item = f ctx var item in
let ctx, next = fold_map ~f ~last ~init:ctx next in
let next_bind = Bindlib.bind_var var next in
( ctx,
Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
item next_bind )
let rec fold_left2 ~f ~init a b =
match a, b with
| Last l1, Last l2 -> init, (l1, l2)
| Cons (item1, next_bind1), Cons (item2, next_bind2) ->
let var, next1, next2 = Bindlib.unbind2 next_bind1 next_bind2 in
fold_left2 ~f ~init:(f init item1 item2 var) next1 next2
| _ -> invalid_arg "fold_left2"
let rec equal ~f ~last a b =
match a, b with
| Last l1, Last l2 -> last l1 l2
| Cons (item1, next_bind1), Cons (item2, next_bind2) ->
f item1 item2
&&
let _, next1, next2 = Bindlib.unbind2 next_bind1 next_bind2 in
equal ~f ~last next1 next2
| _ -> false
let rec compare ~f ~last a b =
match a, b with
| Last l1, Last l2 -> last l1 l2
| Cons (item1, next_bind1), Cons (item2, next_bind2) -> (
match f item1 item2 with
| 0 ->
let _, next1, next2 = Bindlib.unbind2 next_bind1 next_bind2 in
compare ~f ~last next1 next2
| n -> n)
| Last _, Cons _ -> -1
| Cons _, Last _ -> 1

View File

@ -0,0 +1,92 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020-2022 Inria,
contributor: Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Bound lists are non-empty linked lists where each element is a binder onto
the next. They are useful for ordered program definitions, like nested
let-ins.
[let a = e1 in e2] is thus represented as [Cons (e1, {a. Last e2})].
The following provides a few utility functions for their traversal and
manipulation. In particular, [map] functions take care of unbinding, then
properly rebinding the variables. *)
open Definitions
type ('e, 'elt, 'last) t = ('e, 'elt, 'last) bound_list =
| Last of 'last
| Cons of 'elt * ('e, ('e, 'elt, 'last) t) binder
val last : (_, _, 'a) t -> 'a
val iter : f:('e Var.t -> 'elt -> unit) -> ('e, 'elt, 'last) t -> 'last
val find : f:('elt -> 'a option) -> (_, 'elt, _) t -> 'a
val fold_left :
f:('acc -> 'elt -> 'e Var.t -> 'acc) ->
init:'acc ->
('e, 'elt, 'last) t ->
'acc * 'last
val fold_left2 :
f:('acc -> 'elt1 -> 'elt2 -> 'e Var.t -> 'acc) ->
init:'acc ->
('e, 'elt1, 'last1) t ->
('e, 'elt2, 'last2) t ->
'acc * ('last1 * 'last2)
val fold_right :
f:('elt -> 'e Var.t -> 'acc -> 'acc) ->
init:('last -> 'acc) ->
('e, 'elt, 'last) t ->
'acc
val fold_lr :
top:'dacc ->
down:('e Var.t -> 'elt -> 'dacc -> 'dacc) ->
bottom:('last -> 'dacc -> 'uacc) ->
up:('e Var.t -> 'elt -> 'uacc -> 'uacc) ->
('e, 'elt, 'last) t ->
'uacc
(** Bi-directional fold: [down] accumulates downwards, starting from [top]; upon
reaching [last], [bottom] is called; then [up] accumulates on the way back
up *)
val map :
f:('e1 Var.t -> 'elt1 -> 'e2 Var.t * 'elt2 Bindlib.box) ->
last:('last1 -> 'last2 Bindlib.box) ->
('e1, 'elt1, 'last1) t ->
('e2, 'elt2, 'last2) t Bindlib.box
val fold_map :
f:('ctx -> 'e1 Var.t -> 'elt1 -> 'ctx * 'e2 Var.t * 'elt2 Bindlib.box) ->
last:('ctx -> 'last1 -> 'ret * 'last2 Bindlib.box) ->
init:'ctx ->
('e1, 'elt1, 'last1) t ->
'ret * ('e2, 'elt2, 'last2) t Bindlib.box
val equal :
f:('elt -> 'elt -> bool) ->
last:('last -> 'last -> bool) ->
(('e, 'elt, 'last) t as 'l) ->
'l ->
bool
val compare :
f:('elt -> 'elt -> int) ->
last:('last -> 'last -> int) ->
(('e, 'elt, 'last) t as 'l) ->
'l ->
int

View File

@ -615,6 +615,12 @@ type ('e, 'b) mbinder = (('a, 'm) naked_gexpr, 'b) Bindlib.mbinder
Note that this structure is at the moment only relevant for [dcalc] and Note that this structure is at the moment only relevant for [dcalc] and
[lcalc], as [scopelang] has its own scope structure, as the name implies. *) [lcalc], as [scopelang] has its own scope structure, as the name implies. *)
(** A linked list, but with a binder for each element into the next:
[x := let a = e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *)
type ('e, 'elt, 'last) bound_list =
| Last of 'last
| Cons of 'elt * ('e, ('e, 'elt, 'last) bound_list) binder
(** This kind annotation signals that the let-binding respects a structural (** This kind annotation signals that the let-binding respects a structural
invariant. These invariants concern the shape of the expression in the invariant. These invariants concern the shape of the expression in the
let-binding, and are documented below. *) let-binding, and are documented below. *)
@ -632,21 +638,17 @@ type 'e scope_let = {
scope_let_kind : scope_let_kind; scope_let_kind : scope_let_kind;
scope_let_typ : typ; scope_let_typ : typ;
scope_let_expr : 'e; scope_let_expr : 'e;
scope_let_next : ('e, 'e scope_body_expr) binder;
(* todo ? Factorise the code_item _list type below and use it here *)
scope_let_pos : Pos.t; scope_let_pos : Pos.t;
} }
constraint 'e = ('a any, _) gexpr constraint 'e = ('a any, _) gexpr
(** This type is parametrized by the expression type so it can be reused in (** This type is parametrized by the expression type so it can be reused in
later intermediate representations. *) later intermediate representations. *)
type 'e scope_body_expr = ('e, 'e scope_let, 'e) bound_list
constraint 'e = ('a any, _) gexpr
(** A scope let-binding has all the information necessary to make a proper (** A scope let-binding has all the information necessary to make a proper
let-binding expression, plus an annotation for the kind of the let-binding let-binding expression, plus an annotation for the kind of the let-binding
that comes from the compilation of a {!module: Scopelang.Ast} statement. *) that comes from the compilation of a {!module: Scopelang.Ast} statement. *)
and 'e scope_body_expr =
| Result of 'e
| ScopeLet of 'e scope_let
constraint 'e = ('a any, _) gexpr
type 'e scope_body = { type 'e scope_body = {
scope_body_input_struct : StructName.t; scope_body_input_struct : StructName.t;
@ -663,13 +665,7 @@ type 'e code_item =
| ScopeDef of ScopeName.t * 'e scope_body | ScopeDef of ScopeName.t * 'e scope_body
| Topdef of TopdefName.t * typ * 'e | Topdef of TopdefName.t * typ * 'e
(** A chained list, but with a binder for each element into the next: type 'e code_item_list = ('e, 'e code_item, unit) bound_list
[x := let a
= e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *)
type 'e code_item_list =
| Nil
| Cons of 'e code_item * ('e, 'e code_item_list) binder
type struct_ctx = typ StructField.Map.t StructName.Map.t type struct_ctx = typ StructField.Map.t StructName.Map.t
type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t

View File

@ -778,6 +778,8 @@ let rec free_vars : ('a, 't) gexpr -> ('a, 't) gexpr Var.Set.t = function
let vs, body = Bindlib.unmbind binder in let vs, body = Bindlib.unmbind binder in
Array.fold_right Var.Set.remove vs (free_vars body) Array.fold_right Var.Set.remove vs (free_vars body)
| e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty | e -> shallow_fold (fun e -> Var.Set.union (free_vars e)) e Var.Set.empty
(* Could also be done with [rebox] followed by [Bindlib.free_vars], if that
returned more than a context *)
(* This function is first defined in [Print], only for dependency reasons *) (* This function is first defined in [Print], only for dependency reasons *)
let skip_wrappers : type a. (a, 'm) gexpr -> (a, 'm) gexpr = Print.skip_wrappers let skip_wrappers : type a. (a, 'm) gexpr -> (a, 'm) gexpr = Print.skip_wrappers

View File

@ -944,7 +944,10 @@ let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list
in in
let to_interpret = let to_interpret =
Expr.make_app (Expr.box e) Expr.make_app (Expr.box e)
[Expr.estruct ~name:s_in ~fields:application_term mark_e] [
Expr.estruct ~name:s_in ~fields:application_term
(Expr.map_ty (fun (_, pos) -> TStruct s_in, pos) mark_e);
]
[TStruct s_in, Expr.pos e] [TStruct s_in, Expr.pos e]
(Expr.pos e) (Expr.pos e)
in in
@ -996,7 +999,10 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list
in in
let to_interpret = let to_interpret =
Expr.make_app (Expr.box e) Expr.make_app (Expr.box e)
[Expr.estruct ~name:s_in ~fields:application_term mark_e] [
Expr.estruct ~name:s_in ~fields:application_term
(Expr.map_ty (fun (_, pos) -> TStruct s_in, pos) mark_e);
]
[TStruct s_in, Expr.pos e] [TStruct s_in, Expr.pos e]
(Expr.pos e) (Expr.pos e)
in in

View File

@ -380,8 +380,7 @@ let optimize_expr :
optimize_expr { decl_ctx } e optimize_expr { decl_ctx } e
let optimize_program (p : 'm program) : 'm program = let optimize_program (p : 'm program) : 'm program =
Bindlib.unbox Program.map_exprs ~f:(optimize_expr p.decl_ctx) ~varf:(fun v -> v) p
(Program.map_exprs ~f:(optimize_expr p.decl_ctx) ~varf:(fun v -> v) p)
let test_iota_reduction_1 () = let test_iota_reduction_1 () =
let x = Var.make "x" in let x = Var.make "x" in

View File

@ -563,7 +563,9 @@ module ExprGen (C : EXPR_PARAM) = struct
(Format.pp_print_list ~pp_sep:Format.pp_print_space (Format.pp_print_list ~pp_sep:Format.pp_print_space
(fun fmt (x, tau) -> (fun fmt (x, tau) ->
match tau with match tau with
| TLit TUnit, _ -> punctuation fmt "("; punctuation fmt ")" | TLit TUnit, _ ->
punctuation fmt "(";
punctuation fmt ")"
| _ -> | _ ->
punctuation fmt "("; punctuation fmt "(";
Format.pp_open_hvbox fmt 2; Format.pp_open_hvbox fmt 2;
@ -710,17 +712,16 @@ module ExprGen (C : EXPR_PARAM) = struct
| EAbs { binder; tys; _ }, _ -> | EAbs { binder; tys; _ }, _ ->
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
let expr = exprb bnd_ctx in let expr = exprb bnd_ctx in
let pp_args fmt = match tys with let pp_args fmt =
| [TLit TUnit, _] -> () match tys with
| [(TLit TUnit, _)] -> ()
| _ -> | _ ->
Format.pp_print_seq ~pp_sep:Format.pp_print_space var fmt Format.pp_print_seq ~pp_sep:Format.pp_print_space var fmt
(Array.to_seq xs); (Array.to_seq xs);
Format.pp_print_space fmt () Format.pp_print_space fmt ()
in in
Format.fprintf fmt "@[<hov 2>%a %t@ %t%a@ %a@]" punctuation Format.fprintf fmt "@[<hov 2>%a %t@ %t%a@ %a@]" punctuation "|"
"|" pp_cons_name pp_cons_name pp_args punctuation "" (rhs expr) body
pp_args
punctuation "" (rhs expr) body
| e -> | e ->
Format.fprintf fmt "@[<hov 2>%a %t@ %a@ %a@]" punctuation "|" Format.fprintf fmt "@[<hov 2>%a %t@ %a@ %a@]" punctuation "|"
pp_cons_name punctuation "" (rhs exprc) e)) pp_cons_name punctuation "" (rhs exprc) e))
@ -782,30 +783,22 @@ let scope_let_kind ?debug:(_debug = true) _ctx fmt k =
| DestructuringSubScopeResults -> keyword fmt "sub_get" | DestructuringSubScopeResults -> keyword fmt "sub_get"
| Assertion -> keyword fmt "assert" | Assertion -> keyword fmt "assert"
let[@ocamlformat "disable"] rec let[@ocamlformat "disable"]
scope_body_expr ?(debug = false) ctx fmt b : unit = scope_body_expr ?(debug = false) ctx fmt b : unit =
match b with let print_scope_let x sl =
| Result e -> Format.fprintf fmt "%a %a" keyword "return" (expr ~debug ()) e
| ScopeLet
{
scope_let_kind = kind;
scope_let_typ;
scope_let_expr;
scope_let_next;
_;
} ->
let x, next = Bindlib.unbind scope_let_next in
Format.fprintf fmt Format.fprintf fmt
"@[<hv 2>@[<hov 4>%a %a %a %a@ %a@ %a@]@ %a@;<1 -2>%a@]@,%a" "@[<hv 2>@[<hov 4>%a %a %a %a@ %a@ %a@]@ %a@;<1 -2>%a@]@,"
keyword "let" keyword "let"
(scope_let_kind ~debug ctx) kind (scope_let_kind ~debug ctx) sl.scope_let_kind
(if debug then var_debug else var) x (if debug then var_debug else var) x
punctuation ":" punctuation ":"
(typ ctx) scope_let_typ (typ ctx) sl.scope_let_typ
punctuation "=" punctuation "="
(expr ~debug ()) scope_let_expr (expr ~debug ()) sl.scope_let_expr
keyword "in" keyword "in"
(scope_body_expr ~debug ctx) next in
let last = BoundList.iter ~f:print_scope_let b in
Format.fprintf fmt "%a %a" keyword "return" (expr ~debug ()) last
let scope_body ?(debug = false) ctx fmt (n, l) : unit = let scope_body ?(debug = false) ctx fmt (n, l) : unit =
let { let {
@ -936,16 +929,12 @@ let code_item ?(debug = false) ?name decl_ctx fmt c =
"let topval" TopdefName.format n op_style ":" (typ decl_ctx) ty op_style "let topval" TopdefName.format n op_style ":" (typ decl_ctx) ty op_style
"=" (expr ~debug ()) e "=" (expr ~debug ()) e
let rec code_item_list ?(debug = false) decl_ctx fmt c = let code_item_list ?(debug = false) decl_ctx fmt c =
match c with BoundList.iter c ~f:(fun x item ->
| Nil -> () code_item ~debug
| Cons (c, b) -> ~name:(Format.asprintf "%a" var_debug x)
let x, cl = Bindlib.unbind b in decl_ctx fmt item;
Format.fprintf fmt "%a @.%a" Format.pp_print_newline fmt ())
(code_item ~debug ~name:(Format.asprintf "%a" var_debug x) decl_ctx)
c
(code_item_list ~debug decl_ctx)
cl
let program ?(debug = false) fmt p = let program ?(debug = false) fmt p =
decl_ctx ~debug p.decl_ctx fmt p.decl_ctx; decl_ctx ~debug p.decl_ctx fmt p.decl_ctx;

View File

@ -17,16 +17,37 @@
open Definitions open Definitions
let map_exprs ~f ~varf { code_items; decl_ctx; lang; module_name } = let map_decl_ctx ~f ctx =
{
ctx with
ctx_enums = EnumName.Map.map (EnumConstructor.Map.map f) ctx.ctx_enums;
ctx_structs = StructName.Map.map (StructField.Map.map f) ctx.ctx_structs;
ctx_topdefs = TopdefName.Map.map f ctx.ctx_topdefs;
}
let map_exprs ?typ ~f ~varf { code_items; decl_ctx; lang; module_name } =
let boxed_prg =
Bindlib.box_apply Bindlib.box_apply
(fun code_items -> { code_items; decl_ctx; lang; module_name }) (fun code_items ->
(Scope.map_exprs ~f ~varf code_items) let decl_ctx =
match typ with None -> decl_ctx | Some f -> map_decl_ctx ~f decl_ctx
in
{ code_items; decl_ctx; lang; module_name })
(Scope.map_exprs ?typ ~f ~varf code_items)
in
assert (Bindlib.is_closed boxed_prg);
Bindlib.unbox boxed_prg
let fold_left_exprs ~f ~init { code_items; _ } = let fold_left ~f ~init { code_items; _ } =
Scope.fold_left ~f:(fun acc e _ -> f acc e) ~init code_items fst @@ BoundList.fold_left ~f:(fun acc e _ -> f acc e) ~init code_items
let fold_right_exprs ~f ~init { code_items; _ } = let fold_exprs ~f ~init prg = Scope.fold_exprs ~f ~init prg.code_items
Scope.fold_right ~f:(fun e _ acc -> f e acc) ~init code_items
let fold_right ~f ~init { code_items; _ } =
BoundList.fold_right
~f:(fun e _ acc -> f e acc)
~init:(fun () -> init)
code_items
let empty_ctx = let empty_ctx =
{ {
@ -42,56 +63,25 @@ let empty_ctx =
let get_scope_body { code_items; _ } scope = let get_scope_body { code_items; _ } scope =
match match
Scope.fold_left ~init:None BoundList.fold_left ~init:None
~f:(fun acc item _ -> ~f:(fun acc item _ ->
match item with match item with
| ScopeDef (name, body) when ScopeName.equal scope name -> Some body | ScopeDef (name, body) when ScopeName.equal scope name -> Some body
| _ -> acc) | _ -> acc)
code_items code_items
with with
| None -> raise Not_found | None, _ -> raise Not_found
| Some body -> body | Some body, _ -> body
let untype : 'm. ('a, 'm) gexpr program -> ('a, untyped) gexpr program = let untype : 'm. ('a, 'm) gexpr program -> ('a, untyped) gexpr program =
fun prg -> Bindlib.unbox (map_exprs ~f:Expr.untype ~varf:Var.translate prg) fun prg -> map_exprs ~f:Expr.untype ~varf:Var.translate prg
let rec find_scope name vars = function let find_scope name =
| Nil -> raise Not_found BoundList.find ~f:(function
| Cons (ScopeDef (n, body), _) when ScopeName.equal name n -> | ScopeDef (n, body) when ScopeName.equal name n -> Some body
List.rev vars, body | _ -> None)
| Cons (_, next_bind) ->
let var, next = Bindlib.unbind next_bind in
find_scope name (var :: vars) next
let rec all_scopes code_item_list =
match code_item_list with
| Nil -> []
| Cons (ScopeDef (n, _), next_bind) ->
let _var, next = Bindlib.unbind next_bind in
n :: all_scopes next
| Cons (_, next_bind) ->
let _var, next = Bindlib.unbind next_bind in
all_scopes next
let to_expr p main_scope = let to_expr p main_scope =
let _, main_scope_body = find_scope main_scope [] p.code_items in let res = Scope.unfold p.decl_ctx p.code_items main_scope in
let res =
Scope.unfold p.decl_ctx p.code_items
(Scope.get_body_mark main_scope_body)
(ScopeName main_scope)
in
Expr.Box.assert_closed (Expr.Box.lift res); Expr.Box.assert_closed (Expr.Box.lift res);
res res
let equal p p' =
(* TODO: include toplevel definitions in this program comparison. *)
let ss = all_scopes p.code_items in
let ss' = all_scopes p'.code_items in
List.length ss = List.length ss'
&& ListLabels.for_all2 ss ss' ~f:(fun s s' ->
ScopeName.equal s s'
&&
let e1 = Expr.unbox @@ to_expr p s in
let e2 = Expr.unbox @@ to_expr p' s in
Expr.equal e1 e2)

View File

@ -23,16 +23,22 @@ val empty_ctx : decl_ctx
(** {2 Transformations} *) (** {2 Transformations} *)
val map_decl_ctx : f:(typ -> typ) -> decl_ctx -> decl_ctx
val map_exprs : val map_exprs :
?typ:(typ -> typ) ->
f:('expr1 -> 'expr2 boxed) -> f:('expr1 -> 'expr2 boxed) ->
varf:('expr1 Var.t -> 'expr2 Var.t) -> varf:('expr1 Var.t -> 'expr2 Var.t) ->
'expr1 program -> 'expr1 program ->
'expr2 program Bindlib.box 'expr2 program
(** If [typ] is specified, definitions in [decl_ctx] are also processed *)
val fold_left_exprs : val fold_left :
f:('a -> 'expr code_item -> 'a) -> init:'a -> 'expr program -> 'a f:('a -> 'expr code_item -> 'a) -> init:'a -> 'expr program -> 'a
val fold_right_exprs : val fold_exprs : f:('a -> 'expr -> typ -> 'a) -> init:'a -> 'expr program -> 'a
val fold_right :
f:('expr code_item -> 'a -> 'a) -> init:'a -> 'expr program -> 'a f:('expr code_item -> 'a -> 'a) -> init:'a -> 'expr program -> 'a
val get_scope_body : val get_scope_body :
@ -45,6 +51,4 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed
corresponding to the main program and returning the main scope as a corresponding to the main program and returning the main scope as a
function. *) function. *)
val equal : val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body
(('a any, _) gexpr as 'e) program -> (('a any, _) gexpr as 'e) program -> bool
(** Warning / todo: only compares program scopes at the moment *)

View File

@ -18,151 +18,80 @@
open Catala_utils open Catala_utils
open Definitions open Definitions
let rec fold_left_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result _ -> init
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
fold_left_lets ~f ~init:(f init scope_let var) next
let rec fold_right_lets ~f ~init scope_body_expr =
match scope_body_expr with
| Result result -> init result
| ScopeLet scope_let ->
let var, next = Bindlib.unbind scope_let.scope_let_next in
let next_result = fold_right_lets ~f ~init next in
f scope_let var next_result
let map_exprs_in_lets : let map_exprs_in_lets :
?transform_types:(typ -> typ) -> ?typ:(typ -> typ) ->
f:('expr1 -> 'expr2 boxed) -> f:('expr1 -> 'expr2 boxed) ->
varf:('expr1 Var.t -> 'expr2 Var.t) -> varf:('expr1 Var.t -> 'expr2 Var.t) ->
'expr1 scope_body_expr -> 'expr1 scope_body_expr ->
'expr2 scope_body_expr Bindlib.box = 'expr2 scope_body_expr Bindlib.box =
fun ?(transform_types = Fun.id) ~f ~varf scope_body_expr -> fun ?(typ = Fun.id) ~f ~varf scope_body_expr ->
fold_right_lets let f e = Expr.Box.lift (f e) in
~f:(fun scope_let var_next acc -> BoundList.map ~last:f
Bindlib.box_apply2 ~f:(fun v scope_let ->
(fun scope_let_next scope_let_expr -> ( varf v,
ScopeLet Bindlib.box_apply
(fun scope_let_expr ->
{ {
scope_let with scope_let with
scope_let_next;
scope_let_expr; scope_let_expr;
scope_let_typ = transform_types scope_let.scope_let_typ; scope_let_typ = typ scope_let.scope_let_typ;
}) })
(Bindlib.bind_var (varf var_next) acc) (f scope_let.scope_let_expr) ))
(Expr.Box.lift (f scope_let.scope_let_expr)))
~init:(fun res ->
Bindlib.box_apply (fun res -> Result res) (Expr.Box.lift (f res)))
scope_body_expr scope_body_expr
let rec fold_left ~f ~init = function let map_exprs ?(typ = Fun.id) ~f ~varf scopes =
| Nil -> init let f v = function
| Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
fold_left ~f ~init:(f init item var) next
let rec fold_right ~f ~init = function
| Nil -> init
| Cons (item, next_bind) ->
let var_next, next = Bindlib.unbind next_bind in
let result_next = fold_right ~f ~init next in
f item var_next result_next
let rec map ~f ~varf = function
| Nil -> Bindlib.box Nil
| Cons (item, next_bind) ->
let item = f item in
let next_bind =
let var, next = Bindlib.unbind next_bind in
Bindlib.bind_var (varf var) (map ~f ~varf next)
in
Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
item next_bind
let rec map_ctx ~f ~varf ctx = function
| Nil -> Bindlib.box Nil
| Cons (item, next_bind) ->
let ctx, item = f ctx item in
let next_bind =
let var, next = Bindlib.unbind next_bind in
Bindlib.bind_var (varf var) (map_ctx ~f ~varf ctx next)
in
Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
item next_bind
let rec fold_map ~f ~varf ctx = function
| Nil -> ctx, Bindlib.box Nil
| Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
let ctx, item = f ctx var item in
let ctx, next = fold_map ~f ~varf ctx next in
let next_bind = Bindlib.bind_var (varf var) next in
( ctx,
Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
item next_bind )
let map_exprs ~f ~varf scopes =
let f = function
| ScopeDef (name, body) -> | ScopeDef (name, body) ->
let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in let scope_input_var, scope_lets = Bindlib.unbind body.scope_body_expr in
let new_body_expr = map_exprs_in_lets ~f ~varf scope_lets in let new_body_expr = map_exprs_in_lets ~typ ~f ~varf scope_lets in
let new_body_expr = let new_body_expr =
Bindlib.bind_var (varf scope_input_var) new_body_expr Bindlib.bind_var (varf scope_input_var) new_body_expr
in in
( varf v,
Bindlib.box_apply Bindlib.box_apply
(fun scope_body_expr -> ScopeDef (name, { body with scope_body_expr })) (fun scope_body_expr ->
new_body_expr ScopeDef (name, { body with scope_body_expr }))
| Topdef (name, typ, expr) -> new_body_expr )
| Topdef (name, ty, expr) ->
( varf v,
Bindlib.box_apply Bindlib.box_apply
(fun e -> Topdef (name, typ, e)) (fun e -> Topdef (name, typ ty, e))
(Expr.Box.lift (f expr)) (Expr.Box.lift (f expr)) )
in in
map ~f ~varf scopes BoundList.map ~f ~last:Bindlib.box scopes
(* TODO: compute the expected body expr arrow type manually instead of [TAny] let fold_exprs ~f ~init scopes =
for double-checking types ? *) let f acc def _ =
let rec get_body_expr_mark = function match def with
| ScopeLet sl -> | Topdef (_, typ, e) -> f acc e typ
let _, e = Bindlib.unbind sl.scope_let_next in | ScopeDef (_, scope) ->
get_body_expr_mark e let _, body = Bindlib.unbind scope.scope_body_expr in
| Result e -> let acc, last =
let m = Mark.get e in BoundList.fold_left body ~init:acc ~f:(fun acc sl _ ->
Expr.with_ty m (Mark.add (Expr.mark_pos m) TAny) f acc sl.scope_let_expr sl.scope_let_typ)
in
f acc last (TStruct scope.scope_body_output_struct, Expr.pos last)
in
fst @@ BoundList.fold_left ~f ~init scopes
let typ body =
let pos = Mark.get (StructName.get_info body.scope_body_input_struct) in
let input_typ = Mark.add pos (TStruct body.scope_body_input_struct) in
let result_typ = Mark.add pos (TStruct body.scope_body_output_struct) in
Mark.add pos (TArrow ([input_typ], result_typ))
let get_body_mark scope_body = let get_body_mark scope_body =
let _, e = Bindlib.unbind scope_body.scope_body_expr in let m0 =
get_body_expr_mark e match Bindlib.unbind scope_body.scope_body_expr with
| _, Last (_, m) | _, Cons ({ scope_let_expr = _, m; _ }, _) -> m
in
Expr.with_ty m0 (typ scope_body)
let rec unfold_body_expr (ctx : decl_ctx) (scope_let : 'e scope_body_expr) = let unfold_body_expr (_ctx : decl_ctx) (scope_let : 'e scope_body_expr) =
match scope_let with BoundList.fold_right scope_let ~init:Expr.rebox ~f:(fun sl var acc ->
| Result e -> Expr.rebox e Expr.make_let_in var sl.scope_let_typ
| ScopeLet (Expr.rebox sl.scope_let_expr)
{ acc sl.scope_let_pos)
scope_let_kind = _;
scope_let_typ;
scope_let_expr;
scope_let_next;
scope_let_pos;
} ->
let var, next = Bindlib.unbind scope_let_next in
Expr.make_let_in var scope_let_typ
(Expr.rebox scope_let_expr)
(unfold_body_expr ctx next)
scope_let_pos
let build_typ_from_sig
(_ctx : decl_ctx)
(scope_input_struct_name : StructName.t)
(scope_return_struct_name : StructName.t)
(pos : Pos.t) : typ =
let input_typ = Mark.add pos (TStruct scope_input_struct_name) in
let result_typ = Mark.add pos (TStruct scope_return_struct_name) in
Mark.add pos (TArrow ([input_typ], result_typ))
let input_type ty io = let input_type ty io =
match io, ty with match io, ty with
@ -171,59 +100,34 @@ let input_type ty io =
| (Runtime.Reentrant, iopos), (ty, tpos) -> TDefault (ty, tpos), iopos | (Runtime.Reentrant, iopos), (ty, tpos) -> TDefault (ty, tpos), iopos
| _, ty -> ty | _, ty -> ty
type 'e scope_name_or_var = ScopeName of ScopeName.t | ScopeVar of 'e Var.t let to_expr (ctx : decl_ctx) (body : 'e scope_body) : 'e boxed =
let to_expr (ctx : decl_ctx) (body : 'e scope_body) (mark_scope : 'm) : 'e boxed
=
let var, body_expr = Bindlib.unbind body.scope_body_expr in let var, body_expr = Bindlib.unbind body.scope_body_expr in
let body_expr = unfold_body_expr ctx body_expr in let body_expr = unfold_body_expr ctx body_expr in
let pos = Expr.pos body_expr in
Expr.make_abs [| var |] body_expr Expr.make_abs [| var |] body_expr
[TStruct body.scope_body_input_struct, Expr.mark_pos mark_scope] [TStruct body.scope_body_input_struct, pos]
(Expr.mark_pos mark_scope) pos
let rec unfold let unfold (ctx : decl_ctx) (s : 'e code_item_list) (main_scope : ScopeName.t) :
(ctx : decl_ctx) 'e boxed =
(s : 'e code_item_list) BoundList.fold_lr s ~top:None
(mark : 'm mark) ~down:(fun v item main ->
(main_scope : 'expr scope_name_or_var) : 'e boxed = match main, item with
match s with | None, ScopeDef (name, body) when ScopeName.equal name main_scope ->
| Nil -> ( Some (Expr.make_var v (get_body_mark body))
match main_scope with | r, _ -> r)
| ScopeVar v -> Expr.make_var v mark ~bottom:(fun () -> function Some v -> v | None -> raise Not_found)
| ScopeName _ -> failwith "should not happen") ~up:(fun var item next ->
| Cons (item, next_bind) -> let e, typ =
let var, next = Bindlib.unbind next_bind in
let typ, expr, pos, is_main =
match item with match item with
| ScopeDef (name, body) -> | ScopeDef (_, body) -> to_expr ctx body, typ body
let pos = Mark.get (ScopeName.get_info name) in | Topdef (_, typ, expr) -> Expr.rebox expr, typ
let body_mark = get_body_mark body in
let is_main =
match main_scope with
| ScopeName n -> ScopeName.equal n name
| ScopeVar _ -> false
in in
let typ = Expr.make_let_in var typ e next (Expr.pos e))
build_typ_from_sig ctx body.scope_body_input_struct
body.scope_body_output_struct pos
in
let expr = to_expr ctx body body_mark in
typ, expr, pos, is_main
| Topdef (name, typ, expr) ->
let pos = Mark.get (TopdefName.get_info name) in
typ, Expr.rebox expr, pos, false
in
let main_scope = if is_main then ScopeVar var else main_scope in
let next = unfold ctx next mark main_scope in
Expr.make_let_in var typ expr next pos
let rec free_vars_body_expr scope_lets = let free_vars_body_expr scope_lets =
match scope_lets with BoundList.fold_right scope_lets ~init:Expr.free_vars ~f:(fun sl v acc ->
| Result e -> Expr.free_vars e Var.Set.union (Var.Set.remove v acc) (Expr.free_vars sl.scope_let_expr))
| ScopeLet { scope_let_expr = e; scope_let_next = next; _ } ->
let v, body = Bindlib.unbind next in
Var.Set.union (Expr.free_vars e)
(Var.Set.remove v (free_vars_body_expr body))
let free_vars_item = function let free_vars_item = function
| ScopeDef (_, { scope_body_expr; _ }) -> | ScopeDef (_, { scope_body_expr; _ }) ->
@ -231,9 +135,8 @@ let free_vars_item = function
Var.Set.remove v (free_vars_body_expr body) Var.Set.remove v (free_vars_body_expr body)
| Topdef (_, _, expr) -> Expr.free_vars expr | Topdef (_, _, expr) -> Expr.free_vars expr
let rec free_vars scopes = let free_vars scopes =
match scopes with BoundList.fold_right scopes
| Nil -> Var.Set.empty ~init:(fun () -> Var.Set.empty)
| Cons (item, next_bind) -> ~f:(fun item v acc ->
let v, next = Bindlib.unbind next_bind in Var.Set.union (Var.Set.remove v acc) (free_vars_item item))
Var.Set.union (Var.Set.remove v (free_vars next)) (free_vars_item item)

View File

@ -23,28 +23,8 @@ open Definitions
(** {2 Traversal functions} *) (** {2 Traversal functions} *)
val fold_left_lets :
f:('a -> 'e scope_let -> 'e Var.t -> 'a) ->
init:'a ->
'e scope_body_expr ->
'a
(** Usage:
[fold_left_lets ~f:(fun acc scope_let scope_let_var -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined. *)
val fold_right_lets :
f:('expr1 scope_let -> 'expr1 Var.t -> 'a -> 'a) ->
init:('expr1 -> 'a) ->
'expr1 scope_body_expr ->
'a
(** Usage:
[fold_right_lets ~f:(fun scope_let scope_let_var acc -> ...) ~init scope_lets],
where [scope_let_var] is the variable bound to the scope let in the next
scope lets to be examined (which are before in the program order). *)
val map_exprs_in_lets : val map_exprs_in_lets :
?transform_types:(typ -> typ) -> ?typ:(typ -> typ) ->
f:('expr1 -> 'expr2 boxed) -> f:('expr1 -> 'expr2 boxed) ->
varf:('expr1 Var.t -> 'expr2 Var.t) -> varf:('expr1 Var.t -> 'expr2 Var.t) ->
'expr1 scope_body_expr -> 'expr1 scope_body_expr ->
@ -58,48 +38,8 @@ val map_exprs_in_lets :
activated, then the resulting types in the scope let left-hand-sides will be activated, then the resulting types in the scope let left-hand-sides will be
reset to [TAny]. *) reset to [TAny]. *)
val fold_left :
f:('a -> 'expr1 code_item -> 'expr1 Var.t -> 'a) ->
init:'a ->
'expr1 code_item_list ->
'a
(** Usage: [fold_left ~f:(fun acc code_def code_var -> ...) ~init code_def],
where [code_var] is the variable bound to the code item in the next code
items to be examined. *)
val fold_right :
f:('expr1 code_item -> 'expr1 Var.t -> 'a -> 'a) ->
init:'a ->
'expr1 code_item_list ->
'a
(** Usage:
[fold_right_scope ~f:(fun scope_def scope_var acc -> ...) ~init scope_def],
where [scope_var] is the variable bound to the scope in the next scopes to
be examined (which are before in the program order). *)
val map :
f:('e1 code_item -> 'e2 code_item Bindlib.box) ->
varf:('e1 Var.t -> 'e2 Var.t) ->
'e1 code_item_list ->
'e2 code_item_list Bindlib.box
val map_ctx :
f:('ctx -> 'e1 code_item -> 'ctx * 'e2 code_item Bindlib.box) ->
varf:('e1 Var.t -> 'e2 Var.t) ->
'ctx ->
'e1 code_item_list ->
'e2 code_item_list Bindlib.box
(** Similar to [map], but a context is passed left-to-right through the given
function *)
val fold_map :
f:('ctx -> 'e1 Var.t -> 'e1 code_item -> 'ctx * 'e2 code_item Bindlib.box) ->
varf:('e1 Var.t -> 'e2 Var.t) ->
'ctx ->
'e1 code_item_list ->
'ctx * 'e2 code_item_list Bindlib.box
val map_exprs : val map_exprs :
?typ:(typ -> typ) ->
f:('expr1 -> 'expr2 boxed) -> f:('expr1 -> 'expr2 boxed) ->
varf:('expr1 Var.t -> 'expr2 Var.t) -> varf:('expr1 Var.t -> 'expr2 Var.t) ->
'expr1 code_item_list -> 'expr1 code_item_list ->
@ -107,28 +47,20 @@ val map_exprs :
(** This is the main map visitor for all the expressions inside all the scopes (** This is the main map visitor for all the expressions inside all the scopes
of the program. *) of the program. *)
val get_body_mark : (_, 'm) gexpr scope_body -> 'm mark val fold_exprs :
f:('acc -> 'expr -> typ -> 'acc) -> init:'acc -> 'expr code_item_list -> 'acc
(** {2 Conversions} *) (** {2 Conversions} *)
val to_expr : val to_expr : decl_ctx -> ('a any, 'm) gexpr scope_body -> ('a, 'm) boxed_gexpr
decl_ctx -> ('a any, 'm) gexpr scope_body -> 'm mark -> ('a, 'm) boxed_gexpr
(** Usage: [to_expr ctx body scope_position] where [scope_position] corresponds (** Usage: [to_expr ctx body scope_position] where [scope_position] corresponds
to the line of the scope declaration for instance. *) to the line of the scope declaration for instance. *)
type 'e scope_name_or_var = ScopeName of ScopeName.t | ScopeVar of 'e Var.t
val unfold : val unfold :
decl_ctx -> decl_ctx -> ((_, 'm) gexpr as 'e) code_item_list -> ScopeName.t -> 'e boxed
((_, 'm) gexpr as 'e) code_item_list ->
'm mark ->
'e scope_name_or_var ->
'e boxed
val build_typ_from_sig : val typ : _ scope_body -> typ
decl_ctx -> StructName.t -> StructName.t -> Pos.t -> typ (** builds the arrow type for the specified scope *)
(** [build_typ_from_sig ctx in_struct out_struct pos] builds the arrow type for
the specified scope *)
val input_type : typ -> Runtime.io_input Mark.pos -> typ val input_type : typ -> Runtime.io_input Mark.pos -> typ
(** Returns the correct input type for scope input variables: this is [typ] for (** Returns the correct input type for scope input variables: this is [typ] for

View File

@ -20,6 +20,7 @@ module Qident = Qident
module Type = Type module Type = Type
module Operator = Operator module Operator = Operator
module Expr = Expr module Expr = Expr
module BoundList = BoundList
module Scope = Scope module Scope = Scope
module Program = Program module Program = Program
module Print = Print module Print = Print

View File

@ -944,46 +944,38 @@ let check_expr ctx ?env ?typ e =
let expr ctx ?(env = Env.empty ctx) ?typ e = let expr ctx ?(env = Env.empty ctx) ?typ e =
Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) (expr_raw ctx ~env ?typ e) Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) (expr_raw ctx ~env ?typ e)
let rec scope_body_expr ctx env ty_out body_expr = let scope_body_expr ctx env ty_out body_expr =
match body_expr with let _env, ret =
| A.Result e -> BoundList.fold_map body_expr ~init:env
~last:(fun env e ->
let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in
let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
Bindlib.box_apply (fun e -> A.Result e) (Expr.Box.lift e') env, Expr.Box.lift e')
| A.ScopeLet ~f:(fun env var scope ->
{ let e0 = scope.A.scope_let_expr in
scope_let_kind; let ty_e = ast_to_typ scope.A.scope_let_typ in
scope_let_typ;
scope_let_expr = e0;
scope_let_next;
scope_let_pos;
} ->
let ty_e = ast_to_typ scope_let_typ in
let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in
wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e; wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e;
(* We could use [typecheck_expr_top_down] rather than this manual (* We could use [typecheck_expr_top_down] rather than this manual
unification, but we get better messages with this order of the [unify] unification, but we get better messages with this order of the
parameters, which keeps location of the type as defined instead of as [unify] parameters, which keeps location of the type as defined
inferred. *) instead of as inferred. *)
let var, next = Bindlib.unbind scope_let_next in ( Env.add var ty_e env,
let env = Env.add var ty_e env in Var.translate var,
let next = scope_body_expr ctx env ty_out next in Bindlib.box_apply
let scope_let_next = Bindlib.bind_var (Var.translate var) next in (fun scope_let_expr ->
Bindlib.box_apply2
(fun scope_let_expr scope_let_next ->
A.ScopeLet
{ {
scope_let_kind; scope with
scope_let_typ = A.scope_let_typ =
(match Mark.remove scope_let_typ with (match scope.A.scope_let_typ with
| TAny -> typ_to_ast ~flags:env.flags (ty e) | TAny, _ -> typ_to_ast ~flags:env.flags (ty e)
| _ -> scope_let_typ); | ty -> ty);
scope_let_expr; A.scope_let_expr;
scope_let_next;
scope_let_pos;
}) })
(Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e)) (Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e))
scope_let_next ))
in
ret
let scope_body ctx env body = let scope_body ctx env body =
let get_pos struct_name = Mark.get (A.StructName.get_info struct_name) in let get_pos struct_name = Mark.get (A.StructName.get_info struct_name) in
@ -1003,33 +995,29 @@ let scope_body ctx env body =
(get_pos body.A.scope_body_output_struct) (get_pos body.A.scope_body_output_struct)
(TArrow ([ty_in], ty_out))) ) (TArrow ([ty_in], ty_out))) )
let rec scopes ctx env = function let scopes ctx env =
| A.Nil -> Bindlib.box A.Nil, env BoundList.fold_map ~init:env
| A.Cons (item, next_bind) -> ~last:(fun ctx () -> ctx, Bindlib.box ())
let var, next = Bindlib.unbind next_bind in ~f:(fun env var item ->
let env, def =
match item with match item with
| A.ScopeDef (name, body) -> | A.ScopeDef (name, body) ->
let body_e, ty_scope = scope_body ctx env body in let body_e, ty_scope = scope_body ctx env body in
( Env.add var ty_scope env, ( Env.add var ty_scope env,
Var.translate var,
Bindlib.box_apply (fun body -> A.ScopeDef (name, body)) body_e ) Bindlib.box_apply (fun body -> A.ScopeDef (name, body)) body_e )
| A.Topdef (name, typ, e) -> | A.Topdef (name, typ, e) ->
let e' = expr_raw ctx ~env ~typ e in let e' = expr_raw ctx ~env ~typ e in
let (A.Custom { custom = uf; _ }) = Mark.get e' in let (A.Custom { custom = uf; _ }) = Mark.get e' in
let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
( Env.add var uf env, ( Env.add var uf env,
Var.translate var,
Bindlib.box_apply Bindlib.box_apply
(fun e -> A.Topdef (name, Expr.ty e', e)) (fun e -> A.Topdef (name, Expr.ty e', e))
(Expr.Box.lift e') ) (Expr.Box.lift e') ))
in
let next', env = scopes ctx env next in
let next_bind' = Bindlib.bind_var (Var.translate var) next' in
( Bindlib.box_apply2 (fun item next -> A.Cons (item, next)) def next_bind',
env )
let program ?fail_on_any ?assume_op_types prg = let program ?fail_on_any ?assume_op_types prg =
let env = Env.empty ?fail_on_any ?assume_op_types prg.A.decl_ctx in let env = Env.empty ?fail_on_any ?assume_op_types prg.A.decl_ctx in
let code_items, new_env = scopes prg.A.decl_ctx env prg.A.code_items in let new_env, code_items = scopes prg.A.decl_ctx env prg.A.code_items in
{ {
A.lang = prg.lang; A.lang = prg.lang;
A.module_name = prg.A.module_name; A.module_name = prg.A.module_name;

View File

@ -5,7 +5,6 @@ let () =
( "Iota-reduction", ( "Iota-reduction",
[ [
test_case "#1" `Quick Shared_ast.Optimizations.test_iota_reduction_1; test_case "#1" `Quick Shared_ast.Optimizations.test_iota_reduction_1;
test_case "#2" `Quick test_case "#2" `Quick Shared_ast.Optimizations.test_iota_reduction_2;
Shared_ast.Optimizations.test_iota_reduction_2;
] ); ] );
] ]

View File

@ -286,11 +286,9 @@ let rec generate_verification_conditions_scope_body_expr
(scope_body_expr : 'm expr scope_body_expr) : (scope_body_expr : 'm expr scope_body_expr) :
ctx * verification_condition list * typed expr list = ctx * verification_condition list * typed expr list =
match scope_body_expr with match scope_body_expr with
| Result _ -> ctx, [], [] | Last _ -> ctx, [], []
| ScopeLet scope_let -> | Cons (scope_let, scope_let_next) ->
let scope_let_var, scope_let_next = let scope_let_var, scope_let_next = Bindlib.unbind scope_let_next in
Bindlib.unbind scope_let.scope_let_next
in
let new_ctx, vc_list, assert_list = let new_ctx, vc_list, assert_list =
match scope_let.scope_let_kind with match scope_let.scope_let_kind with
| Assertion -> ( | Assertion -> (
@ -378,7 +376,8 @@ let generate_verification_conditions_code_items
(decl_ctx : decl_ctx) (decl_ctx : decl_ctx)
(code_items : 'm expr code_item_list) (code_items : 'm expr code_item_list)
(s : ScopeName.t option) : verification_condition list = (s : ScopeName.t option) : verification_condition list =
Scope.fold_left let conditions, () =
BoundList.fold_left
~f:(fun vcs item _ -> ~f:(fun vcs item _ ->
match item with match item with
| Topdef _ -> [] | Topdef _ -> []
@ -402,9 +401,10 @@ let generate_verification_conditions_code_items
scope_variables_typs = scope_variables_typs =
Var.Map.empty Var.Map.empty
(* We don't need to add the typ of the scope input var here (* We don't need to add the typ of the scope input var here
because it will never appear in an expression for which we because it will never appear in an expression for which
generate a verification conditions (the big struct is we generate a verification conditions (the big struct is
destructured with a series of let bindings just after. )*); destructured with a series of let bindings just after.
)*);
} }
in in
let _, vcs, asserts = let _, vcs, asserts =
@ -421,6 +421,8 @@ let generate_verification_conditions_code_items
in in
new_vcs @ vcs) new_vcs @ vcs)
~init:[] code_items ~init:[] code_items
in
conditions
let generate_verification_conditions (p : 'm program) (s : ScopeName.t option) : let generate_verification_conditions (p : 'm program) (s : ScopeName.t option) :
verification_condition list = verification_condition list =

View File

@ -25,7 +25,6 @@ module S_in = struct
end end
let s (s_in: S_in.t) : S.t = let s (s_in: S_in.t) : S.t =
let sr_: money = let sr_: money =
try try
@ -70,6 +69,7 @@ let s (s_in: S_in.t) : S.t =
let half_ : integer -> decimal = let half_ : integer -> decimal =
fun (x_: integer) -> o_div_int_int x_ (integer_of_string "2") fun (x_: integer) -> o_div_int_int x_ (integer_of_string "2")
let () = let () =
Runtime_ocaml.Runtime.register_module "Mod_def" Runtime_ocaml.Runtime.register_module "Mod_def"
[ "S", Obj.repr s; [ "S", Obj.repr s;

View File

@ -52,7 +52,6 @@ module S_in = struct
end end
let s (s_in: S_in.t) : S.t = let s (s_in: S_in.t) : S.t =
let a_: unit -> bool = s_in.S_in.a_in in let a_: unit -> bool = s_in.S_in.a_in in
let a_: bool = let a_: bool =
@ -91,6 +90,7 @@ let s (s_in: S_in.t) : S.t =
start_line=7; start_column=18; end_line=7; end_column=19; start_line=7; start_column=18; end_line=7; end_column=19;
law_headings=["Article"]})) in law_headings=["Article"]})) in
{S.a = a_} {S.a = a_}
let () = let () =
Runtime_ocaml.Runtime.register_module "Let_in2" Runtime_ocaml.Runtime.register_module "Let_in2"
[ "S", Obj.repr s ] [ "S", Obj.repr s ]

View File

@ -50,7 +50,6 @@ module ScopeB_in = struct
end end
let scope_a (scope_a_in: ScopeA_in.t) : ScopeA.t = let scope_a (scope_a_in: ScopeA_in.t) : ScopeA.t =
let a_: bool = true in let a_: bool = true in
{ScopeA.a = a_} {ScopeA.a = a_}
@ -60,6 +59,7 @@ let scope_b (scope_b_in: ScopeB_in.t) : ScopeB.t =
let scope_a_dot_a_: bool = result_.ScopeA.a in let scope_a_dot_a_: bool = result_.ScopeA.a in
let a_: bool = scope_a_dot_a_ in let a_: bool = scope_a_dot_a_ in
{ScopeB.a = a_} {ScopeB.a = a_}
Generating entry points for scopes: ScopeA ScopeB Generating entry points for scopes: ScopeA ScopeB
let entry_scopes = [ let entry_scopes = [

View File

@ -40,7 +40,7 @@ $ catala Interpret -t -s HousingComputation --debug
[DEBUG] Translating to default calculus... [DEBUG] Translating to default calculus...
[DEBUG] Typechecking again... [DEBUG] Typechecking again...
[DEBUG] Starting interpretation... [DEBUG] Starting interpretation...
[LOG] ≔ HousingComputation.f: λ (x_76: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨(let result_77 : RentComputation = (#{→ RentComputation.direct} (λ (RentComputation_in_78: RentComputation_in) → let g_79 : integer → integer = #{≔ RentComputation.g} (λ (x1_80: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_80 +! 1⟩⟩ | false ⊢ ∅ ⟩) in let f_81 : integer → integer = #{≔ RentComputation.f} (λ (x1_82: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} g_79) #{≔ RentComputation.g.input0} (x1_82 +! 1)⟩⟩ | false ⊢ ∅ ⟩) in { RentComputation f = f_81; })) #{≔ RentComputation.direct.input} {RentComputation_in} in let result1_83 : RentComputation = { RentComputation f = λ (param0_84: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} result_77.f) #{≔ RentComputation.f.input0} param0_84; } in #{← RentComputation.direct} #{≔ RentComputation.direct.output} if #{☛ RentComputation.direct.output} true then result1_83 else result1_83).f x_76⟩⟩ | false ⊢ ∅ ⟩ [LOG] ≔ HousingComputation.f: λ (x_67: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨(let result_68 : RentComputation = (#{→ RentComputation.direct} (λ (RentComputation_in_69: RentComputation_in) → let g_70 : integer → integer = #{≔ RentComputation.g} (λ (x1_71: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_71 +! 1⟩⟩ | false ⊢ ∅ ⟩) in let f_72 : integer → integer = #{≔ RentComputation.f} (λ (x1_73: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} g_70) #{≔ RentComputation.g.input0} (x1_73 +! 1)⟩⟩ | false ⊢ ∅ ⟩) in { RentComputation f = f_72; })) #{≔ RentComputation.direct.input} {RentComputation_in} in let result1_74 : RentComputation = { RentComputation f = λ (param0_75: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} result_68.f) #{≔ RentComputation.f.input0} param0_75; } in #{← RentComputation.direct} #{≔ RentComputation.direct.output} if #{☛ RentComputation.direct.output} true then result1_74 else result1_74).f x_67⟩⟩ | false ⊢ ∅ ⟩
[LOG] ☛ Definition applied: [LOG] ☛ Definition applied:
┌─⯈ tests/test_scope/good/scope_call3.catala_en:8.14-8.20: ┌─⯈ tests/test_scope/good/scope_call3.catala_en:8.14-8.20:
└─┐ └─┐
@ -55,14 +55,14 @@ $ catala Interpret -t -s HousingComputation --debug
│ ‾ │ ‾
[LOG] → RentComputation.direct [LOG] → RentComputation.direct
[LOG] ≔ RentComputation.direct.input: {RentComputation_in} [LOG] ≔ RentComputation.direct.input: {RentComputation_in}
[LOG] ≔ RentComputation.g: λ (x_85: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x_85 +! 1⟩⟩ | false ⊢ ∅ ⟩ [LOG] ≔ RentComputation.g: λ (x_76: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x_76 +! 1⟩⟩ | false ⊢ ∅ ⟩
[LOG] ≔ RentComputation.f: λ (x_86: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_87: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_87 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_86 +! 1)⟩⟩ | false ⊢ ∅ ⟩ [LOG] ≔ RentComputation.f: λ (x_77: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_78: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_78 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_77 +! 1)⟩⟩ | false ⊢ ∅ ⟩
[LOG] ☛ Definition applied: [LOG] ☛ Definition applied:
┌─⯈ tests/test_scope/good/scope_call3.catala_en:7.29-7.54: ┌─⯈ tests/test_scope/good/scope_call3.catala_en:7.29-7.54:
└─┐ └─┐
7 │ definition f of x equals (output of RentComputation).f of x 7 │ definition f of x equals (output of RentComputation).f of x
│ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ │ ‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
[LOG] ≔ RentComputation.direct.output: { RentComputation f = λ (param0_88: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} { RentComputation f = λ (x_89: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_90: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_90 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_89 +! 1)⟩⟩ | false ⊢ ∅ ⟩; }.f) #{≔ RentComputation.f.input0} param0_88; } [LOG] ≔ RentComputation.direct.output: { RentComputation f = λ (param0_79: integer) → #{← RentComputation.f} #{≔ RentComputation.f.output} (#{→ RentComputation.f} { RentComputation f = λ (x_80: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨#{← RentComputation.g} #{≔ RentComputation.g.output} (#{→ RentComputation.g} (λ (x1_81: integer) → error_empty ⟨ ⟨#{☛ } true ⊢ ⟨x1_81 +! 1⟩⟩ | false ⊢ ∅ ⟩)) #{≔ RentComputation.g.input0} (x_80 +! 1)⟩⟩ | false ⊢ ∅ ⟩; }.f) #{≔ RentComputation.f.input0} param0_79; }
[LOG] ← RentComputation.direct [LOG] ← RentComputation.direct
[LOG] → RentComputation.f [LOG] → RentComputation.f
[LOG] ≔ RentComputation.f.input0: 1 [LOG] ≔ RentComputation.f.input0: 1