From 73173285e4fa4c7f14c45b3dd2fe12fde2569f96 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Mon, 24 Oct 2022 18:25:20 +0200 Subject: [PATCH] Scope calls: proper handling of context vars Also proper error messages on bad scope input specifications. * Still needs more tests --- compiler/scopelang/scope_to_dcalc.ml | 108 +++++++++++++------- compiler/shared_ast/expr.ml | 12 ++- compiler/shared_ast/expr.mli | 8 ++ tests/test_scope/good/scope_call2.catala_en | 26 +++++ 4 files changed, 117 insertions(+), 37 deletions(-) create mode 100644 tests/test_scope/good/scope_call2.catala_en diff --git a/compiler/scopelang/scope_to_dcalc.ml b/compiler/scopelang/scope_to_dcalc.ml index 5db95d8e..faa4d91c 100644 --- a/compiler/scopelang/scope_to_dcalc.ml +++ b/compiler/scopelang/scope_to_dcalc.ml @@ -31,6 +31,10 @@ type 'm scope_sig_ctx = { (** Var representing the scope input inside the scope func *) scope_sig_input_struct : StructName.t; (** Scope input *) scope_sig_output_struct : StructName.t; (** Scope output *) + scope_sig_in_fields : + (StructFieldName.t * Ast.io_input Marked.pos) ScopeVarMap.t; + (** Mapping between the input scope variables and the input struct fields. + The boolean is true for 'context' variables which need to be thunked. *) } type 'm scope_sigs_ctx = 'm scope_sig_ctx ScopeMap.t @@ -142,6 +146,14 @@ let collapse_similar_outcomes (type m) (excepts : m Ast.expr list) : in excepts +let thunk_scope_arg io_in e = + let silent_var = Var.make "_" in + let pos = Marked.get_mark io_in in + match Marked.unmark io_in with + | Ast.NoInput -> invalid_arg "thunk_scope_arg" + | Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e) + | Ast.Reentrant -> Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos + let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) : 'm Dcalc.Ast.expr boxed = let m = Marked.get_mark e in @@ -228,23 +240,46 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) : | EScopeCall (sc_name, fields) -> let pos = Expr.mark_pos m in let sc_sig = ScopeMap.find sc_name ctx.scopes_parameters in - let struct_def = StructMap.find sc_sig.scope_sig_input_struct ctx.structs in - let struct_fields = - (* Fixme: the correspondance of the two lists is fragile (see also the - conversion of [Call] *) - List.map2 - (fun (sc_var, e) (fld_name, _ty) -> - (* pretty weak check, but better than nothing for now *) - assert ( - Marked.unmark (ScopeVar.get_info sc_var) ^ "_in" - = Marked.unmark (StructFieldName.get_info fld_name)); - translate_expr ctx e) - (ScopeVarMap.bindings fields) - struct_def + let in_var_map = + ScopeVarMap.merge + (fun var_name str_field expr -> + let expr = + match str_field, expr with + | Some (_, (Ast.Reentrant, _)), None -> + Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos))) + | _ -> expr + in + match str_field, expr with + | None, None -> None + | Some (fld, io_in), Some e -> + Some (fld, thunk_scope_arg io_in (translate_expr ctx e)) + | Some (fld, _), None -> + Errors.raise_multispanned_error + [ + None, pos; + ( Some "Declaration of the Missing input variable", + Marked.get_mark (StructFieldName.get_info fld) ); + ] + "Definition of input variable '%a' missing in this scope call" + ScopeVar.format_t var_name + | None, Some _ -> + Errors.raise_multispanned_error + [ + None, pos; + ( Some "Declaration of scope '%a'", + Marked.get_mark (ScopeName.get_info sc_name) ); + ] + "Unknown input variable '%a' in scope call of '%a'" + ScopeVar.format_t var_name ScopeName.format_t sc_name) + sc_sig.scope_sig_in_fields fields + in + let field_map = + ScopeVarMap.fold + (fun _ (fld, e) acc -> StructFieldMap.add fld e acc) + in_var_map StructFieldMap.empty in let arg_struct = - Expr.etuple struct_fields (Some sc_sig.scope_sig_input_struct) - (mark_tany m pos) + Expr.make_struct field_map sc_sig.scope_sig_input_struct (mark_tany m pos) in Expr.eapp (Expr.evar sc_sig.scope_sig_scope_var (mark_tany m pos)) @@ -418,7 +453,6 @@ let translate_rule tau, a_io, e ) -> - let _pos_mark, pos_mark_as = pos_mark_mk e in let a_name = Marked.map_under_mark (fun str -> @@ -431,16 +465,7 @@ let translate_rule (VarDef (Marked.unmark tau)) [sigma_name, pos_sigma; a_name] in - let silent_var = Var.make "_" in - let thunked_or_nonempty_new_e = - match Marked.unmark a_io.io_input with - | NoInput -> failwith "should not happen" - | OnlyInput -> Expr.eerroronempty new_e (pos_mark_as subs_var) - | Reentrant -> - Expr.make_abs [| silent_var |] new_e - [TLit TUnit, var_def_pos] - var_def_pos - in + let thunked_or_nonempty_new_e = thunk_scope_arg a_io.Ast.io_input new_e in ( (fun next -> Bindlib.box_apply2 (fun next thunked_or_nonempty_new_e -> @@ -772,18 +797,15 @@ let translate_scope_decl scope_input_variables (next, List.length scope_input_variables - 1)) in - let scope_input_struct_fields = + let field_map = List.map - (fun (var_ctx, dvar) -> - let struct_field_name = - StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma) - in - struct_field_name, input_var_typ var_ctx) + (fun (var_ctx, _) -> + let var = var_ctx.scope_var_name in + let field, _ = ScopeVarMap.find var scope_sig.scope_sig_in_fields in + field, input_var_typ var_ctx) scope_input_variables in - let new_struct_ctx = - StructMap.singleton scope_input_struct_name scope_input_struct_fields - in + let new_struct_ctx = StructMap.singleton scope_input_struct_name field_map in ( Bindlib.box_apply (fun scope_body_expr -> { @@ -819,6 +841,19 @@ let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program = (fun s -> s ^ "_in") (ScopeName.get_info scope_name)) in + let scope_sig_in_fields = + ScopeVarMap.filter_map + (fun dvar (_, vis) -> + match Marked.unmark vis.Ast.io_input with + | NoInput -> None + | OnlyInput | Reentrant -> + let info = ScopeVar.get_info dvar in + let s = Marked.unmark info ^ "_in" in + Some + ( StructFieldName.fresh (s, Marked.get_mark info), + vis.Ast.io_input )) + scope.scope_sig + in { scope_sig_local_vars = List.map @@ -833,11 +868,12 @@ let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program = scope_sig_input_var = scope_input_var; scope_sig_input_struct = scope_input_struct_name; scope_sig_output_struct = scope_return_struct_name; + scope_sig_in_fields; }) prgm.program_scopes in (* the resulting expression is the list of definitions of all the scopes, - ending with the top-level scope. The decl_ctx is allocated in left-to-right + ending with the top-level scope. The decl_ctx is filled in left-to-right order, then the chained scopes aggregated from the right. *) let rec translate_scopes decl_ctx = function | scope_name :: next_scopes -> diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index bbcf3e2c..8532843b 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -851,7 +851,17 @@ let make_tuple el structname m0 = let m = fold_marks (fun posl -> List.hd posl) - (fun ml -> TTuple (List.map (fun t -> t.ty) ml), (List.hd ml).pos) + (fun ml -> + let pos = (List.hd ml).pos in + match structname with + | Some n -> TStruct n, pos + | None -> TTuple (List.map (fun t -> t.ty) ml), pos) (List.map (fun e -> Marked.get_mark e) el) in etuple el structname m + +let make_struct fieldmap structname m = + let fields = + List.rev (StructFieldMap.fold (fun _ e acc -> e :: acc) fieldmap []) + in + make_tuple fields (Some structname) m diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index d5a2cffd..aa99bde9 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -296,6 +296,14 @@ val make_tuple : (** Builds a tuple; the mark argument is only used as witness and for position when building 0-uples *) +val make_struct : + (([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr StructFieldMap.t -> + StructName.t -> + 'm mark -> + ('a, 'm mark) boxed_gexpr +(** Builds the tuple of values for the given struct with proper ordering, + assuming the structfieldmap contains the fields defined for structname *) + (** {2 Transformations} *) val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr diff --git a/tests/test_scope/good/scope_call2.catala_en b/tests/test_scope/good/scope_call2.catala_en new file mode 100644 index 00000000..1ba84f86 --- /dev/null +++ b/tests/test_scope/good/scope_call2.catala_en @@ -0,0 +1,26 @@ +```catala +declaration scope Toto: + context bar content integer + output foo content integer + +scope Toto: + definition bar equals 1 + definition foo equals 1212 + bar + +declaration scope Titi: + output fizz content Toto + output fuzz content Toto + toto scope Toto + +scope Titi: + definition toto.bar equals 44 + definition fizz equals Toto of {} + definition fuzz equals Toto of {--bar: 111} +``` + +```catala-test-inline +$ catala Interpret -s Titi +[RESULT] Computation successful! Results: +[RESULT] fizz = Toto {"foo"= 1213} +[RESULT] fuzz = Toto {"foo"= 1323} +```