diff --git a/compiler/dcalc/optimizations.ml b/compiler/dcalc/optimizations.ml index b9b4aaf2..4b8eb6f8 100644 --- a/compiler/dcalc/optimizations.ml +++ b/compiler/dcalc/optimizations.ml @@ -190,38 +190,71 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked) let optimize_expr (decl_ctx : decl_ctx) (e : expr Pos.marked) = partial_evaluation { var_values = VarMap.empty; decl_ctx } e +let rec scope_lets_map + (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) + (ctx : 'a) + (scope_body_expr : scope_body_expr) : scope_body_expr Bindlib.box = + match scope_body_expr with + | Result e -> Bindlib.box_apply (fun e' -> Result e') (t ctx e) + | ScopeLet scope_let -> + let var, next = Bindlib.unbind scope_let.scope_let_next in + let new_scope_let_expr = t ctx scope_let.scope_let_expr in + let new_next = scope_lets_map t ctx next in + let new_next = Bindlib.bind_var var new_next in + Bindlib.box_apply2 + (fun new_scope_let_expr new_next -> + ScopeLet + { + scope_let with + scope_let_expr = new_scope_let_expr; + scope_let_next = new_next; + }) + new_scope_let_expr new_next + +let rec scopes_map + (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) + (ctx : 'a) + (scopes : scopes) : scopes Bindlib.box = + match scopes with + | Nil -> Bindlib.box Nil + | ScopeDef scope_def -> + let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in + let scope_arg_var, scope_body_expr = + Bindlib.unbind scope_def.scope_body.scope_body_expr + in + let new_scope_body_expr = scope_lets_map t ctx scope_body_expr in + let new_scope_body_expr = + Bindlib.bind_var scope_arg_var new_scope_body_expr + in + let new_scope_next = scopes_map t ctx scope_next in + let new_scope_next = Bindlib.bind_var scope_var new_scope_next in + Bindlib.box_apply2 + (fun new_scope_body_expr new_scope_next -> + ScopeDef + { + scope_def with + scope_next = new_scope_next; + scope_body = + { + scope_def.scope_body with + scope_body_expr = new_scope_body_expr; + }; + }) + new_scope_body_expr new_scope_next + let program_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx : 'a) - (p : program) : program = - { - p with - scopes = - List.map - (fun (s_name, s_var, s_body) -> - let new_s_body = - { - s_body with - scope_body_lets = - List.map - (fun scope_let -> - { - scope_let with - scope_let_expr = - Bindlib.unbox - (Bindlib.box_apply (t ctx) scope_let.scope_let_expr); - }) - s_body.scope_body_lets; - } - in - (s_name, s_var, new_s_body)) - p.scopes; - } + (p : program) : program Bindlib.box = + Bindlib.box_apply + (fun new_scopes -> { p with scopes = new_scopes }) + (scopes_map t ctx p.scopes) let optimize_program (p : program) : program = - program_map partial_evaluation - { var_values = VarMap.empty; decl_ctx = p.decl_ctx } - p + Bindlib.unbox + (program_map partial_evaluation + { var_values = VarMap.empty; decl_ctx = p.decl_ctx } + p) let rec remove_all_logs (e : expr Pos.marked) : expr Pos.marked Bindlib.box = let pos = Pos.get_position e in diff --git a/compiler/driver.ml b/compiler/driver.ml index 10a27b19..6ba9a823 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -215,10 +215,17 @@ let driver source_file (options : Cli.options) : int = if Option.is_some options.ex_scope then Format.fprintf fmt "%a\n" (Dcalc.Print.format_scope ~debug:options.debug prgm.decl_ctx) - (let _, _, s = - List.find (fun (name, _, _) -> name = scope_uid) prgm.scopes - in - (scope_uid, s)) + ( scope_uid, + Option.get + (Dcalc.Ast.fold_scope_defs ~init:None + ~f:(fun acc scope_def -> + if + Dcalc.Ast.ScopeName.compare scope_def.scope_name + scope_uid + = 0 + then Some scope_def.scope_body + else acc) + prgm.scopes) ) else Format.fprintf fmt "%a\n" (Dcalc.Print.format_expr prgm.decl_ctx) diff --git a/compiler/lcalc/compile_with_exceptions.ml b/compiler/lcalc/compile_with_exceptions.ml index f7068237..3b9c6bf2 100644 --- a/compiler/lcalc/compile_with_exceptions.ml +++ b/compiler/lcalc/compile_with_exceptions.ml @@ -147,32 +147,33 @@ and translate_expr (ctx : ctx) (e : D.expr Pos.marked) : | D.EDefault (exceptions, just, cons) -> translate_default ctx exceptions just cons (Pos.get_position e) +let rec translate_scopes + (decl_ctx : D.decl_ctx) (ctx : A.Var.t D.VarMap.t) (scopes : D.scopes) : + A.scope_body list = + match scopes with + | Nil -> [] + | ScopeDef scope_def -> + let scope_var, scope_next = Bindlib.unbind scope_def.scope_next in + let new_n = A.Var.make (Bindlib.name_of scope_var, Pos.no_pos) in + let new_scope = + { + Ast.scope_body_name = scope_def.scope_name; + scope_body_var = new_n; + scope_body_expr = + Bindlib.unbox + (translate_expr + (D.VarMap.map (fun v -> A.make_var (v, Pos.no_pos)) ctx) + (Bindlib.unbox + (D.build_whole_scope_expr decl_ctx scope_def.scope_body + (Pos.get_position + (Dcalc.Ast.ScopeName.get_info scope_def.scope_name))))); + } + in + let new_ctx = D.VarMap.add scope_var new_n ctx in + new_scope :: translate_scopes decl_ctx new_ctx scope_next + let translate_program (prgm : D.program) : A.program = { - scopes = - (let acc, _ = - List.fold_left - (fun ((acc, ctx) : _ * A.Var.t D.VarMap.t) (scope_name, n, e) -> - let new_n = A.Var.make (Bindlib.name_of n, Pos.no_pos) in - let new_acc = - { - Ast.scope_body_name = scope_name; - scope_body_var = new_n; - scope_body_expr = - Bindlib.unbox - (translate_expr - (D.VarMap.map (fun v -> A.make_var (v, Pos.no_pos)) ctx) - (Bindlib.unbox - (D.build_whole_scope_expr prgm.decl_ctx e - (Pos.get_position - (Dcalc.Ast.ScopeName.get_info scope_name))))); - } - :: acc - in - let new_ctx = D.VarMap.add n new_n ctx in - (new_acc, new_ctx)) - ([], D.VarMap.empty) prgm.scopes - in - List.rev acc); + scopes = translate_scopes prgm.decl_ctx D.VarMap.empty prgm.scopes; decl_ctx = prgm.decl_ctx; } diff --git a/compiler/scopelang/scope_to_dcalc.ml b/compiler/scopelang/scope_to_dcalc.ml index dea17275..6547e44d 100644 --- a/compiler/scopelang/scope_to_dcalc.ml +++ b/compiler/scopelang/scope_to_dcalc.ml @@ -335,14 +335,16 @@ let rec translate_expr (ctx : ctx) (e : Ast.expr Pos.marked) : (** The result of a rule translation is a list of assignment, with variables and expressions. We also return the new translation context available after the - assignment to use in later rule translations. The list is actually a list of - list because we want to group in assignments that are independent of each - other to speed up the translation by minimizing Bindlib.bind_mvar *) + assignment to use in later rule translations. The list is actually a + continuation yielding a [Dcalc.scope_body_expr] by giving it what should + come later in the chain of let-bindings. *) let translate_rule (ctx : ctx) (rule : Ast.rule) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) : - Dcalc.Ast.scope_let list * ctx = + (Dcalc.Ast.scope_body_expr Bindlib.box -> + Dcalc.Ast.scope_body_expr Bindlib.box) + * ctx = match rule with | Definition ((ScopeVar a, var_def_pos), tau, a_io, e) -> let a_name = Ast.ScopeVar.get_info (Pos.unmark a) in @@ -367,14 +369,19 @@ let translate_rule (Dcalc.Ast.VarDef (Pos.unmark tau)) [ (sigma_name, pos_sigma); a_name ] in - ( [ - { - Dcalc.Ast.scope_let_var = (a_var, Pos.get_position a); - Dcalc.Ast.scope_let_typ = tau; - Dcalc.Ast.scope_let_expr = merged_expr; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.ScopeVarDefinition; - }; - ], + ( (fun next -> + Bindlib.box_apply2 + (fun next merged_expr -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_typ = tau; + Dcalc.Ast.scope_let_expr = merged_expr; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.ScopeVarDefinition; + Dcalc.Ast.scope_let_pos = Pos.get_position a; + }) + (Bindlib.bind_var a_var next) + merged_expr), { ctx with scope_vars = @@ -416,20 +423,25 @@ let translate_rule [ (Dcalc.Ast.TLit TUnit, var_def_pos) ] var_def_pos in - ( [ - { - Dcalc.Ast.scope_let_var = (a_var, Pos.get_position a_name); - Dcalc.Ast.scope_let_typ = - (match Pos.unmark a_io.io_input with - | NoInput -> failwith "should not happen" - | OnlyInput -> tau - | Reentrant -> - ( Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), - var_def_pos )); - Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition; - }; - ], + ( (fun next -> + Bindlib.box_apply2 + (fun next thunked_or_nonempty_new_e -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = Pos.get_position a_name; + Dcalc.Ast.scope_let_typ = + (match Pos.unmark a_io.io_input with + | NoInput -> failwith "should not happen" + | OnlyInput -> tau + | Reentrant -> + ( Dcalc.Ast.TArrow ((TLit TUnit, var_def_pos), tau), + var_def_pos )); + Dcalc.Ast.scope_let_expr = thunked_or_nonempty_new_e; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.SubScopeVarDefinition; + }) + (Bindlib.bind_var a_var next) + thunked_or_nonempty_new_e), { ctx with subscope_vars = @@ -543,38 +555,51 @@ let translate_rule Some called_scope_return_struct ), pos_sigma ) in - let call_scope_let = - { - Dcalc.Ast.scope_let_var = (result_tuple_var, pos_sigma); - Dcalc.Ast.scope_let_kind = Dcalc.Ast.CallingSubScope; - Dcalc.Ast.scope_let_typ = result_tuple_typ; - Dcalc.Ast.scope_let_expr = call_expr; - } + let call_scope_let (next : Dcalc.Ast.scope_body_expr Bindlib.box) = + Bindlib.box_apply2 + (fun next call_expr -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = pos_sigma; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.CallingSubScope; + Dcalc.Ast.scope_let_typ = result_tuple_typ; + Dcalc.Ast.scope_let_expr = call_expr; + }) + (Bindlib.bind_var result_tuple_var next) + call_expr in - let result_bindings_lets = - List.mapi - (fun i (var_ctx, v) -> - { - Dcalc.Ast.scope_let_var = (v, pos_sigma); - Dcalc.Ast.scope_let_typ = (var_ctx.scope_var_typ, pos_sigma); - Dcalc.Ast.scope_let_kind = Dcalc.Ast.DestructuringSubScopeResults; - Dcalc.Ast.scope_let_expr = - Bindlib.box_apply - (fun r -> - ( Dcalc.Ast.ETupleAccess - ( r, - i, - Some called_scope_return_struct, - List.map - (fun (var_ctx, _) -> - (var_ctx.scope_var_typ, pos_sigma)) - all_subscope_output_vars_dcalc ), - pos_sigma )) - (Dcalc.Ast.make_var (result_tuple_var, pos_sigma)); - }) + let result_bindings_lets (next : Dcalc.Ast.scope_body_expr Bindlib.box) = + List.fold_right + (fun (var_ctx, v) (next, i) -> + ( Bindlib.box_apply2 + (fun next r -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = pos_sigma; + Dcalc.Ast.scope_let_typ = + (var_ctx.scope_var_typ, pos_sigma); + Dcalc.Ast.scope_let_kind = + Dcalc.Ast.DestructuringSubScopeResults; + Dcalc.Ast.scope_let_expr = + ( Dcalc.Ast.ETupleAccess + ( r, + i, + Some called_scope_return_struct, + List.map + (fun (var_ctx, _) -> + (var_ctx.scope_var_typ, pos_sigma)) + all_subscope_output_vars_dcalc ), + pos_sigma ); + }) + (Bindlib.bind_var v next) + (Dcalc.Ast.make_var (result_tuple_var, pos_sigma)), + i - 1 )) all_subscope_output_vars_dcalc + (next, List.length all_subscope_output_vars_dcalc - 1) in - ( call_scope_let :: result_bindings_lets, + ( (fun next -> call_scope_let (fst (result_bindings_lets next))), { ctx with subscope_vars = @@ -589,24 +614,28 @@ let translate_rule } ) | Assertion e -> let new_e = translate_expr ctx e in - ( [ - { - Dcalc.Ast.scope_let_var = - (Dcalc.Ast.Var.make ("_", Pos.get_position e), Pos.get_position e); - Dcalc.Ast.scope_let_typ = (Dcalc.Ast.TLit TUnit, Pos.get_position e); - Dcalc.Ast.scope_let_expr = - (* To ensure that we throw an error if the value is not defined, - we add an check "ErrorOnEmpty" here. *) - Bindlib.box_apply - (fun new_e -> - Pos.same_pos_as - (Dcalc.Ast.EAssert - (Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position e)) - e) - new_e; - Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion; - }; - ], + ( (fun next -> + Bindlib.box_apply2 + (fun next new_e -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = Pos.get_position e; + Dcalc.Ast.scope_let_typ = + (Dcalc.Ast.TLit TUnit, Pos.get_position e); + Dcalc.Ast.scope_let_expr = + (* To ensure that we throw an error if the value is not + defined, we add an check "ErrorOnEmpty" here. *) + Pos.same_pos_as + (Dcalc.Ast.EAssert + (Dcalc.Ast.ErrorOnEmpty new_e, Pos.get_position e)) + new_e; + Dcalc.Ast.scope_let_kind = Dcalc.Ast.Assertion; + }) + (Bindlib.bind_var + (Dcalc.Ast.Var.make ("_", Pos.get_position e)) + next) + new_e), ctx ) let translate_rules @@ -614,15 +643,16 @@ let translate_rules (rules : Ast.rule list) ((sigma_name, pos_sigma) : Utils.Uid.MarkedString.info) (sigma_return_struct_name : Ast.StructName.t) : - Dcalc.Ast.scope_let list * Dcalc.Ast.expr Pos.marked Bindlib.box * ctx = + Dcalc.Ast.scope_body_expr Bindlib.box * ctx = let scope_lets, new_ctx = List.fold_left (fun (scope_lets, ctx) rule -> let new_scope_lets, new_ctx = translate_rule ctx rule (sigma_name, pos_sigma) in - (scope_lets @ new_scope_lets, new_ctx)) - ([], ctx) rules + ((fun next -> scope_lets (new_scope_lets next)), new_ctx)) + ((fun next -> next), ctx) + rules in let scope_variables = Ast.ScopeVarMap.bindings new_ctx.scope_vars in let scope_output_variables = @@ -640,14 +670,19 @@ let translate_rules Dcalc.Ast.make_var (dcalc_var, pos_sigma)) scope_output_variables)) in - (scope_lets, return_exp, new_ctx) + ( scope_lets + (Bindlib.box_apply + (fun return_exp -> Dcalc.Ast.Result return_exp) + return_exp), + new_ctx ) let translate_scope_decl (struct_ctx : Ast.struct_ctx) (enum_ctx : Ast.enum_ctx) (sctx : scope_sigs_ctx) (scope_name : Ast.ScopeName.t) - (sigma : Ast.scope_decl) : Dcalc.Ast.scope_body * Dcalc.Ast.struct_ctx = + (sigma : Ast.scope_decl) : + Dcalc.Ast.scope_body Bindlib.box * Dcalc.Ast.struct_ctx = let sigma_info = Ast.ScopeName.get_info sigma.scope_decl_name in let scope_sig = Ast.ScopeMap.find sigma.scope_decl_name sctx in let scope_variables = scope_sig.scope_sig_local_vars in @@ -679,7 +714,7 @@ let translate_scope_decl let scope_input_struct_name = scope_sig.scope_sig_input_struct in let scope_return_struct_name = scope_sig.scope_sig_output_struct in let pos_sigma = Pos.get_position sigma_info in - let rules, return_exp, ctx = + let rules_with_return_expr, ctx = translate_rules ctx sigma.scope_decl_rules sigma_info scope_return_struct_name in @@ -716,27 +751,34 @@ let translate_scope_decl pos_sigma ) | NoInput -> failwith "should not happen" in - let input_destructurings = - List.mapi - (fun i (var_ctx, v) -> - { - Dcalc.Ast.scope_let_kind = Dcalc.Ast.DestructuringInputStruct; - Dcalc.Ast.scope_let_var = (v, pos_sigma); - Dcalc.Ast.scope_let_typ = input_var_typ var_ctx; - Dcalc.Ast.scope_let_expr = - Bindlib.box_apply - (fun r -> - ( Dcalc.Ast.ETupleAccess - ( r, - i, - Some scope_input_struct_name, - List.map - (fun (var_ctx, _) -> input_var_typ var_ctx) - scope_input_variables ), - pos_sigma )) - (Dcalc.Ast.make_var (scope_input_var, pos_sigma)); - }) - scope_input_variables + let input_destructurings (next : Dcalc.Ast.scope_body_expr Bindlib.box) = + fst + (List.fold_right + (fun (var_ctx, v) (next, i) -> + ( Bindlib.box_apply2 + (fun next r -> + Dcalc.Ast.ScopeLet + { + Dcalc.Ast.scope_let_kind = + Dcalc.Ast.DestructuringInputStruct; + Dcalc.Ast.scope_let_next = next; + Dcalc.Ast.scope_let_pos = pos_sigma; + Dcalc.Ast.scope_let_typ = input_var_typ var_ctx; + Dcalc.Ast.scope_let_expr = + ( Dcalc.Ast.ETupleAccess + ( r, + i, + Some scope_input_struct_name, + List.map + (fun (var_ctx, _) -> input_var_typ var_ctx) + scope_input_variables ), + pos_sigma ); + }) + (Bindlib.bind_var v next) + (Dcalc.Ast.make_var (scope_input_var, pos_sigma)), + i - 1 )) + scope_input_variables + (next, List.length scope_input_variables - 1)) in let scope_return_struct_fields = List.map @@ -761,13 +803,15 @@ let translate_scope_decl (Ast.StructMap.singleton scope_return_struct_name scope_return_struct_fields) in - ( { - Dcalc.Ast.scope_body_lets = input_destructurings @ rules; - Dcalc.Ast.scope_body_result = return_exp; - Dcalc.Ast.scope_body_input_struct = scope_input_struct_name; - Dcalc.Ast.scope_body_output_struct = scope_return_struct_name; - Dcalc.Ast.scope_body_arg = scope_input_var; - }, + ( Bindlib.box_apply + (fun scope_body_expr -> + { + Dcalc.Ast.scope_body_expr; + Dcalc.Ast.scope_body_input_struct = scope_input_struct_name; + Dcalc.Ast.scope_body_output_struct = scope_return_struct_name; + }) + (Bindlib.bind_var scope_input_var + (input_destructurings rules_with_return_expr)), new_struct_ctx ) let translate_program (prgm : Ast.program) :