diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 60aec150..575d7810 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -27,6 +27,10 @@ type scope_input_var_ctx = { scope_input_name : StructField.t; scope_input_io : Runtime.io_input Mark.pos; scope_input_typ : naked_typ; + scope_input_thunked : bool; + (* For reentrant variables: if true, the type t of the field has been + changed to (unit -> t). Otherwise, the type was already a function and + wasn't changed so no additional wrapping will be needed *) } type 'm scope_ref = @@ -193,19 +197,30 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) : in excepts -let thunk_scope_arg ~is_func io_in e = +let input_var_needs_thunking typ io_in = (* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so that we can put them in default terms at the initialisation of the function body, allowing an empty error to recover the default value. *) - let silent_var = Var.make "_" in - let pos = Mark.get io_in in - match Mark.remove io_in with - | Runtime.NoInput -> invalid_arg "thunk_scope_arg" - | Runtime.OnlyInput -> Expr.eerroronempty e (Mark.get e) - | Runtime.Reentrant -> - (* we don't need to thunk expressions that are already functions *) - if is_func then e - else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos + match Mark.remove io_in.Desugared.Ast.io_input, typ with + | Runtime.Reentrant, TArrow _ -> + false (* we don't need to thunk expressions that are already functions *) + | Runtime.Reentrant, _ -> true + | _ -> false + +let input_var_typ typ io_in = + let pos = Mark.get io_in.Desugared.Ast.io_input in + if input_var_needs_thunking typ io_in then + TArrow ([TLit TUnit, pos], (typ, pos)), pos + else typ, pos + +let thunk_scope_arg var_ctx e = + match var_ctx.scope_input_io, var_ctx.scope_input_thunked with + | (Runtime.NoInput, _), _ -> invalid_arg "thunk_scope_arg" + | (Runtime.OnlyInput, _), false -> Expr.eerroronempty e (Mark.get e) + | (Runtime.Reentrant, _), false -> e + | (Runtime.Reentrant, pos), true -> + Expr.make_abs [| Var.make "_" |] e [TLit TUnit, pos] pos + | _ -> assert false let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : 'm Ast.expr boxed = @@ -246,23 +261,27 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let in_var_map = ScopeVar.Map.merge (fun var_name (str_field : scope_input_var_ctx option) expr -> - let expr = - match str_field, expr with - | Some { scope_input_io = Reentrant, _; _ }, None -> - Some (Expr.unbox (Expr.eemptyerror (mark_tany m pos))) - | _ -> expr - in match str_field, expr with - | None, None -> None + | None, None -> assert false + | Some ({ scope_input_io = Reentrant, iopos; _ } as var_ctx), None -> + let ty0 = + match var_ctx.scope_input_typ with + | TArrow ([_], ty) -> ty + | _ -> assert false + (* reentrant field must be thunked with correct function type at + this point *) + in + Some + ( var_ctx.scope_input_name, + Expr.make_abs + [| Var.make "_" |] + (Expr.eemptyerror (Expr.with_ty m ty0)) + [TAny, iopos] + pos ) | Some var_ctx, Some e -> Some ( var_ctx.scope_input_name, - thunk_scope_arg - ~is_func: - (match var_ctx.scope_input_typ with - | TArrow _ -> true - | _ -> false) - var_ctx.scope_input_io (translate_expr ctx e) ) + thunk_scope_arg var_ctx (translate_expr ctx e) ) | Some var_ctx, None -> Message.raise_multispanned_error [ @@ -662,9 +681,14 @@ let translate_rule }) [sigma_name, pos_sigma; a_name] in - let is_func = match Mark.remove tau with TArrow _ -> true | _ -> false in let thunked_or_nonempty_new_e = - thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e + match a_io.Desugared.Ast.io_input with + | Runtime.NoInput, _ -> assert false + | Runtime.OnlyInput, _ -> Expr.eerroronempty new_e (Mark.get new_e) + | Runtime.Reentrant, pos -> ( + match Mark.remove tau with + | TArrow _ -> new_e + | _ -> Expr.thunk_term new_e (Expr.with_pos pos (Mark.get new_e))) in ( (fun next -> Bindlib.box_apply2 @@ -673,13 +697,7 @@ let translate_rule { scope_let_next = next; scope_let_pos = Mark.get a_name; - scope_let_typ = - (match Mark.remove a_io.io_input with - | NoInput -> failwith "should not happen" - | OnlyInput -> tau - | Reentrant -> - if is_func then tau - else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos); + scope_let_typ = input_var_typ (Mark.remove tau) a_io; scope_let_expr = thunked_or_nonempty_new_e; scope_let_kind = SubScopeVarDefinition; }) @@ -922,17 +940,6 @@ let translate_rules (Expr.Box.lift return_exp)), new_ctx ) -let input_var_typ typ io_in = - match io_in.Desugared.Ast.io_input with - | Runtime.OnlyInput, pos -> typ, pos - | Runtime.Reentrant, pos -> ( - match typ with - | TArrow _ -> typ, pos - | _ -> - ( TArrow ([TLit TUnit, pos], (typ, pos)), - pos )) - | Runtime.NoInput, _ -> invalid_arg "input_var_typ" - (* From a scope declaration and definitions, create the corresponding scope body wrapped in the appropriate call convention. *) let translate_scope_decl @@ -1032,7 +1039,8 @@ let translate_scope_decl scope_let_kind = DestructuringInputStruct; scope_let_next = next; scope_let_pos = pos_sigma; - scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io; + scope_let_typ = + input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io; scope_let_expr = ( EStructAccess { name = scope_input_struct_name; e = r; field }, @@ -1045,11 +1053,11 @@ let translate_scope_decl in Bindlib.box_apply (fun scope_body_expr -> - { - scope_body_expr; - scope_body_input_struct = scope_input_struct_name; - scope_body_output_struct = scope_return_struct_name; - }) + { + scope_body_expr; + scope_body_input_struct = scope_input_struct_name; + scope_body_output_struct = scope_return_struct_name; + }) (Bindlib.bind_var scope_input_var (input_destructurings rules_with_return_expr)) @@ -1097,7 +1105,10 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = { scope_input_name = StructField.fresh (s, Mark.get info); scope_input_io = vis.Desugared.Ast.io_input; - scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis); + scope_input_typ = + Mark.remove (input_var_typ (Mark.remove typ) vis); + scope_input_thunked = + input_var_needs_thunking (Mark.remove typ) vis; }) scope.Scopelang.Ast.scope_sig in @@ -1141,18 +1152,16 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = let add_scope_in_structs scope_sigs structs = ScopeName.Map.fold (fun _ scope_sig_ctx acc -> - let fields = - ScopeVar.Map.fold - (fun _ sivc acc -> - let pos = - Mark.get (StructField.get_info sivc.scope_input_name) - in - StructField.Map.add sivc.scope_input_name - (sivc.scope_input_typ, pos) - acc) - scope_sig_ctx.scope_sig_in_fields StructField.Map.empty - in - StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) + let fields = + ScopeVar.Map.fold + (fun _ sivc acc -> + let pos = Mark.get (StructField.get_info sivc.scope_input_name) in + StructField.Map.add sivc.scope_input_name + (sivc.scope_input_typ, pos) + acc) + scope_sig_ctx.scope_sig_in_fields StructField.Map.empty + in + StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) scope_sigs.scope_sigs structs in let rec gather_module_in_structs acc sctx = @@ -1171,7 +1180,6 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = (gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules); } in - Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs; let top_ctx = let toplevel_vars = TopdefName.Map.mapi diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index cc76ae78..68d790c7 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -425,7 +425,7 @@ let rec evaluate_operator (* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and* of the OCaml runtime *) let rec runtime_to_val : - type d e. + type d e. (decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) -> decl_ctx -> 'm mark -> @@ -481,7 +481,7 @@ let rec runtime_to_val : | TAny -> assert false and val_to_runtime : - type d e . + type d e. (decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) -> decl_ctx -> typ -> @@ -542,8 +542,7 @@ and val_to_runtime : curry [] targs | _ -> Message.raise_internal_error - "Could not convert value of type %a to runtime: %a" - (Print.typ ctx) ty + "Could not convert value of type %a to runtime: %a" (Print.typ ctx) ty Expr.format v let rec evaluate_expr : diff --git a/tests/test_modules/good/mod_def_context.catala_en b/tests/test_modules/good/mod_def_context.catala_en new file mode 100644 index 00000000..abf27d47 --- /dev/null +++ b/tests/test_modules/good/mod_def_context.catala_en @@ -0,0 +1,125 @@ +Testing adequacy of the scope calling convention with various types of +parameters (reentrant, functions ...) ; and different calls (through subscopes +or direct scope calls). The main part of the test is in `mod_use_context`. + +> Module Mod_def_context + +```catala-metadata +declaration scope S: + context output ci content integer + context output cm content money + context output cfun1 content decimal depends on x content integer + input output cfun2 content decimal depends on x content integer +``` + +```catala +scope S: + definition ci equals 0 + definition cm equals $0 + definition cfun1 of x equals x / 2 +``` + +Now testing direct calls within the same module + +```catala +declaration third content decimal + depends on x content integer + equals x / 3 + +declaration quarter content decimal + depends on x content integer + equals x / 4 +``` + +```catala +declaration scope Stest: + output o1 content S + output o2 content S + output x11 content decimal + output x12 content decimal + output x21 content decimal + output x22 content decimal + +scope Stest: + definition o1 equals + output of S with { -- cfun2: quarter } + definition o2 equals + output of S with { + -- ci: 1 + -- cm: $1 + -- cfun1: third + -- cfun2: quarter + } + definition x11 equals o1.cfun1 of 24 + definition x12 equals o1.cfun2 of 24 + definition x21 equals o2.cfun1 of 24 + definition x22 equals o2.cfun2 of 24 +``` + +```catala-test-inline +$ catala interpret -s Stest +[RESULT] Computation successful! Results: +[RESULT] +o1 = S { -- ci: 0 -- cm: $0.00 -- cfun1: -- cfun2: } +[RESULT] +o2 = S { -- ci: 1 -- cm: $1.00 -- cfun1: -- cfun2: } +[RESULT] x11 = 12.0 +[RESULT] x12 = 6.0 +[RESULT] x21 = 8.0 +[RESULT] x22 = 6.0 +``` + +### Testing subscopes (with and without context override) + +```catala +declaration scope TestSubDefault: + sub scope S + output ci content integer + output cm content money + output x11 content decimal + output x12 content decimal + +scope TestSubDefault: + definition sub.cfun2 of x equals quarter of x + definition ci equals sub.ci + definition cm equals sub.cm + definition x11 equals sub.cfun1 of 24 + definition x12 equals sub.cfun2 of 24 +``` + +```catala-test-inline +$ catala interpret -s TestSubDefault +[RESULT] Computation successful! Results: +[RESULT] ci = 0 +[RESULT] cm = $0.00 +[RESULT] x11 = 12.0 +[RESULT] x12 = 6.0 +``` + +```catala +declaration scope TestSubOverride: + sub scope S + output ci content integer + output cm content money + output x21 content decimal + output x22 content decimal + +scope TestSubOverride: + definition sub.ci equals 1 + definition sub.cm equals $1 + definition sub.cfun1 of x equals third of x + definition sub.cfun2 of x equals quarter of x + definition ci equals sub.ci + definition cm equals sub.cm + definition x21 equals sub.cfun1 of 24 + definition x22 equals sub.cfun2 of 24 +``` + +```catala-test-inline +$ catala interpret -s TestSubOverride +[RESULT] Computation successful! Results: +[RESULT] ci = 1 +[RESULT] cm = $1.00 +[RESULT] x21 = 8.0 +[RESULT] x22 = 6.0 +```