mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Scope calls: proper handling of context vars
Also proper error messages on bad scope input specifications. * Still needs more tests
This commit is contained in:
parent
b19a7660fc
commit
73173285e4
@ -31,6 +31,10 @@ type 'm scope_sig_ctx = {
|
||||
(** Var representing the scope input inside the scope func *)
|
||||
scope_sig_input_struct : StructName.t; (** Scope input *)
|
||||
scope_sig_output_struct : StructName.t; (** Scope output *)
|
||||
scope_sig_in_fields :
|
||||
(StructFieldName.t * Ast.io_input Marked.pos) ScopeVarMap.t;
|
||||
(** Mapping between the input scope variables and the input struct fields.
|
||||
The boolean is true for 'context' variables which need to be thunked. *)
|
||||
}
|
||||
|
||||
type 'm scope_sigs_ctx = 'm scope_sig_ctx ScopeMap.t
|
||||
@ -142,6 +146,14 @@ let collapse_similar_outcomes (type m) (excepts : m Ast.expr list) :
|
||||
in
|
||||
excepts
|
||||
|
||||
let thunk_scope_arg io_in e =
|
||||
let silent_var = Var.make "_" in
|
||||
let pos = Marked.get_mark io_in in
|
||||
match Marked.unmark io_in with
|
||||
| Ast.NoInput -> invalid_arg "thunk_scope_arg"
|
||||
| Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e)
|
||||
| Ast.Reentrant -> Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
|
||||
|
||||
let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) :
|
||||
'm Dcalc.Ast.expr boxed =
|
||||
let m = Marked.get_mark e in
|
||||
@ -228,23 +240,46 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) :
|
||||
| EScopeCall (sc_name, fields) ->
|
||||
let pos = Expr.mark_pos m in
|
||||
let sc_sig = ScopeMap.find sc_name ctx.scopes_parameters in
|
||||
let struct_def = StructMap.find sc_sig.scope_sig_input_struct ctx.structs in
|
||||
let struct_fields =
|
||||
(* Fixme: the correspondance of the two lists is fragile (see also the
|
||||
conversion of [Call] *)
|
||||
List.map2
|
||||
(fun (sc_var, e) (fld_name, _ty) ->
|
||||
(* pretty weak check, but better than nothing for now *)
|
||||
assert (
|
||||
Marked.unmark (ScopeVar.get_info sc_var) ^ "_in"
|
||||
= Marked.unmark (StructFieldName.get_info fld_name));
|
||||
translate_expr ctx e)
|
||||
(ScopeVarMap.bindings fields)
|
||||
struct_def
|
||||
let in_var_map =
|
||||
ScopeVarMap.merge
|
||||
(fun var_name str_field expr ->
|
||||
let expr =
|
||||
match str_field, expr with
|
||||
| Some (_, (Ast.Reentrant, _)), None ->
|
||||
Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos)))
|
||||
| _ -> expr
|
||||
in
|
||||
match str_field, expr with
|
||||
| None, None -> None
|
||||
| Some (fld, io_in), Some e ->
|
||||
Some (fld, thunk_scope_arg io_in (translate_expr ctx e))
|
||||
| Some (fld, _), None ->
|
||||
Errors.raise_multispanned_error
|
||||
[
|
||||
None, pos;
|
||||
( Some "Declaration of the Missing input variable",
|
||||
Marked.get_mark (StructFieldName.get_info fld) );
|
||||
]
|
||||
"Definition of input variable '%a' missing in this scope call"
|
||||
ScopeVar.format_t var_name
|
||||
| None, Some _ ->
|
||||
Errors.raise_multispanned_error
|
||||
[
|
||||
None, pos;
|
||||
( Some "Declaration of scope '%a'",
|
||||
Marked.get_mark (ScopeName.get_info sc_name) );
|
||||
]
|
||||
"Unknown input variable '%a' in scope call of '%a'"
|
||||
ScopeVar.format_t var_name ScopeName.format_t sc_name)
|
||||
sc_sig.scope_sig_in_fields fields
|
||||
in
|
||||
let field_map =
|
||||
ScopeVarMap.fold
|
||||
(fun _ (fld, e) acc -> StructFieldMap.add fld e acc)
|
||||
in_var_map StructFieldMap.empty
|
||||
in
|
||||
let arg_struct =
|
||||
Expr.etuple struct_fields (Some sc_sig.scope_sig_input_struct)
|
||||
(mark_tany m pos)
|
||||
Expr.make_struct field_map sc_sig.scope_sig_input_struct (mark_tany m pos)
|
||||
in
|
||||
Expr.eapp
|
||||
(Expr.evar sc_sig.scope_sig_scope_var (mark_tany m pos))
|
||||
@ -418,7 +453,6 @@ let translate_rule
|
||||
tau,
|
||||
a_io,
|
||||
e ) ->
|
||||
let _pos_mark, pos_mark_as = pos_mark_mk e in
|
||||
let a_name =
|
||||
Marked.map_under_mark
|
||||
(fun str ->
|
||||
@ -431,16 +465,7 @@ let translate_rule
|
||||
(VarDef (Marked.unmark tau))
|
||||
[sigma_name, pos_sigma; a_name]
|
||||
in
|
||||
let silent_var = Var.make "_" in
|
||||
let thunked_or_nonempty_new_e =
|
||||
match Marked.unmark a_io.io_input with
|
||||
| NoInput -> failwith "should not happen"
|
||||
| OnlyInput -> Expr.eerroronempty new_e (pos_mark_as subs_var)
|
||||
| Reentrant ->
|
||||
Expr.make_abs [| silent_var |] new_e
|
||||
[TLit TUnit, var_def_pos]
|
||||
var_def_pos
|
||||
in
|
||||
let thunked_or_nonempty_new_e = thunk_scope_arg a_io.Ast.io_input new_e in
|
||||
( (fun next ->
|
||||
Bindlib.box_apply2
|
||||
(fun next thunked_or_nonempty_new_e ->
|
||||
@ -772,18 +797,15 @@ let translate_scope_decl
|
||||
scope_input_variables
|
||||
(next, List.length scope_input_variables - 1))
|
||||
in
|
||||
let scope_input_struct_fields =
|
||||
let field_map =
|
||||
List.map
|
||||
(fun (var_ctx, dvar) ->
|
||||
let struct_field_name =
|
||||
StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma)
|
||||
in
|
||||
struct_field_name, input_var_typ var_ctx)
|
||||
(fun (var_ctx, _) ->
|
||||
let var = var_ctx.scope_var_name in
|
||||
let field, _ = ScopeVarMap.find var scope_sig.scope_sig_in_fields in
|
||||
field, input_var_typ var_ctx)
|
||||
scope_input_variables
|
||||
in
|
||||
let new_struct_ctx =
|
||||
StructMap.singleton scope_input_struct_name scope_input_struct_fields
|
||||
in
|
||||
let new_struct_ctx = StructMap.singleton scope_input_struct_name field_map in
|
||||
( Bindlib.box_apply
|
||||
(fun scope_body_expr ->
|
||||
{
|
||||
@ -819,6 +841,19 @@ let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program =
|
||||
(fun s -> s ^ "_in")
|
||||
(ScopeName.get_info scope_name))
|
||||
in
|
||||
let scope_sig_in_fields =
|
||||
ScopeVarMap.filter_map
|
||||
(fun dvar (_, vis) ->
|
||||
match Marked.unmark vis.Ast.io_input with
|
||||
| NoInput -> None
|
||||
| OnlyInput | Reentrant ->
|
||||
let info = ScopeVar.get_info dvar in
|
||||
let s = Marked.unmark info ^ "_in" in
|
||||
Some
|
||||
( StructFieldName.fresh (s, Marked.get_mark info),
|
||||
vis.Ast.io_input ))
|
||||
scope.scope_sig
|
||||
in
|
||||
{
|
||||
scope_sig_local_vars =
|
||||
List.map
|
||||
@ -833,11 +868,12 @@ let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program =
|
||||
scope_sig_input_var = scope_input_var;
|
||||
scope_sig_input_struct = scope_input_struct_name;
|
||||
scope_sig_output_struct = scope_return_struct_name;
|
||||
scope_sig_in_fields;
|
||||
})
|
||||
prgm.program_scopes
|
||||
in
|
||||
(* the resulting expression is the list of definitions of all the scopes,
|
||||
ending with the top-level scope. The decl_ctx is allocated in left-to-right
|
||||
ending with the top-level scope. The decl_ctx is filled in left-to-right
|
||||
order, then the chained scopes aggregated from the right. *)
|
||||
let rec translate_scopes decl_ctx = function
|
||||
| scope_name :: next_scopes ->
|
||||
|
@ -851,7 +851,17 @@ let make_tuple el structname m0 =
|
||||
let m =
|
||||
fold_marks
|
||||
(fun posl -> List.hd posl)
|
||||
(fun ml -> TTuple (List.map (fun t -> t.ty) ml), (List.hd ml).pos)
|
||||
(fun ml ->
|
||||
let pos = (List.hd ml).pos in
|
||||
match structname with
|
||||
| Some n -> TStruct n, pos
|
||||
| None -> TTuple (List.map (fun t -> t.ty) ml), pos)
|
||||
(List.map (fun e -> Marked.get_mark e) el)
|
||||
in
|
||||
etuple el structname m
|
||||
|
||||
let make_struct fieldmap structname m =
|
||||
let fields =
|
||||
List.rev (StructFieldMap.fold (fun _ e acc -> e :: acc) fieldmap [])
|
||||
in
|
||||
make_tuple fields (Some structname) m
|
||||
|
@ -296,6 +296,14 @@ val make_tuple :
|
||||
(** Builds a tuple; the mark argument is only used as witness and for position
|
||||
when building 0-uples *)
|
||||
|
||||
val make_struct :
|
||||
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr StructFieldMap.t ->
|
||||
StructName.t ->
|
||||
'm mark ->
|
||||
('a, 'm mark) boxed_gexpr
|
||||
(** Builds the tuple of values for the given struct with proper ordering,
|
||||
assuming the structfieldmap contains the fields defined for structname *)
|
||||
|
||||
(** {2 Transformations} *)
|
||||
|
||||
val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr
|
||||
|
26
tests/test_scope/good/scope_call2.catala_en
Normal file
26
tests/test_scope/good/scope_call2.catala_en
Normal file
@ -0,0 +1,26 @@
|
||||
```catala
|
||||
declaration scope Toto:
|
||||
context bar content integer
|
||||
output foo content integer
|
||||
|
||||
scope Toto:
|
||||
definition bar equals 1
|
||||
definition foo equals 1212 + bar
|
||||
|
||||
declaration scope Titi:
|
||||
output fizz content Toto
|
||||
output fuzz content Toto
|
||||
toto scope Toto
|
||||
|
||||
scope Titi:
|
||||
definition toto.bar equals 44
|
||||
definition fizz equals Toto of {}
|
||||
definition fuzz equals Toto of {--bar: 111}
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala Interpret -s Titi
|
||||
[RESULT] Computation successful! Results:
|
||||
[RESULT] fizz = Toto {"foo"= 1213}
|
||||
[RESULT] fuzz = Toto {"foo"= 1323}
|
||||
```
|
Loading…
Reference in New Issue
Block a user