Scope calls: proper handling of context vars

Also proper error messages on bad scope input specifications.

* Still needs more tests
This commit is contained in:
Louis Gesbert 2022-10-24 18:25:20 +02:00
parent b19a7660fc
commit 73173285e4
4 changed files with 117 additions and 37 deletions

View File

@ -31,6 +31,10 @@ type 'm scope_sig_ctx = {
(** Var representing the scope input inside the scope func *)
scope_sig_input_struct : StructName.t; (** Scope input *)
scope_sig_output_struct : StructName.t; (** Scope output *)
scope_sig_in_fields :
(StructFieldName.t * Ast.io_input Marked.pos) ScopeVarMap.t;
(** Mapping between the input scope variables and the input struct fields.
The boolean is true for 'context' variables which need to be thunked. *)
}
type 'm scope_sigs_ctx = 'm scope_sig_ctx ScopeMap.t
@ -142,6 +146,14 @@ let collapse_similar_outcomes (type m) (excepts : m Ast.expr list) :
in
excepts
let thunk_scope_arg io_in e =
let silent_var = Var.make "_" in
let pos = Marked.get_mark io_in in
match Marked.unmark io_in with
| Ast.NoInput -> invalid_arg "thunk_scope_arg"
| Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e)
| Ast.Reentrant -> Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) :
'm Dcalc.Ast.expr boxed =
let m = Marked.get_mark e in
@ -228,23 +240,46 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Ast.expr) :
| EScopeCall (sc_name, fields) ->
let pos = Expr.mark_pos m in
let sc_sig = ScopeMap.find sc_name ctx.scopes_parameters in
let struct_def = StructMap.find sc_sig.scope_sig_input_struct ctx.structs in
let struct_fields =
(* Fixme: the correspondance of the two lists is fragile (see also the
conversion of [Call] *)
List.map2
(fun (sc_var, e) (fld_name, _ty) ->
(* pretty weak check, but better than nothing for now *)
assert (
Marked.unmark (ScopeVar.get_info sc_var) ^ "_in"
= Marked.unmark (StructFieldName.get_info fld_name));
translate_expr ctx e)
(ScopeVarMap.bindings fields)
struct_def
let in_var_map =
ScopeVarMap.merge
(fun var_name str_field expr ->
let expr =
match str_field, expr with
| Some (_, (Ast.Reentrant, _)), None ->
Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos)))
| _ -> expr
in
match str_field, expr with
| None, None -> None
| Some (fld, io_in), Some e ->
Some (fld, thunk_scope_arg io_in (translate_expr ctx e))
| Some (fld, _), None ->
Errors.raise_multispanned_error
[
None, pos;
( Some "Declaration of the Missing input variable",
Marked.get_mark (StructFieldName.get_info fld) );
]
"Definition of input variable '%a' missing in this scope call"
ScopeVar.format_t var_name
| None, Some _ ->
Errors.raise_multispanned_error
[
None, pos;
( Some "Declaration of scope '%a'",
Marked.get_mark (ScopeName.get_info sc_name) );
]
"Unknown input variable '%a' in scope call of '%a'"
ScopeVar.format_t var_name ScopeName.format_t sc_name)
sc_sig.scope_sig_in_fields fields
in
let field_map =
ScopeVarMap.fold
(fun _ (fld, e) acc -> StructFieldMap.add fld e acc)
in_var_map StructFieldMap.empty
in
let arg_struct =
Expr.etuple struct_fields (Some sc_sig.scope_sig_input_struct)
(mark_tany m pos)
Expr.make_struct field_map sc_sig.scope_sig_input_struct (mark_tany m pos)
in
Expr.eapp
(Expr.evar sc_sig.scope_sig_scope_var (mark_tany m pos))
@ -418,7 +453,6 @@ let translate_rule
tau,
a_io,
e ) ->
let _pos_mark, pos_mark_as = pos_mark_mk e in
let a_name =
Marked.map_under_mark
(fun str ->
@ -431,16 +465,7 @@ let translate_rule
(VarDef (Marked.unmark tau))
[sigma_name, pos_sigma; a_name]
in
let silent_var = Var.make "_" in
let thunked_or_nonempty_new_e =
match Marked.unmark a_io.io_input with
| NoInput -> failwith "should not happen"
| OnlyInput -> Expr.eerroronempty new_e (pos_mark_as subs_var)
| Reentrant ->
Expr.make_abs [| silent_var |] new_e
[TLit TUnit, var_def_pos]
var_def_pos
in
let thunked_or_nonempty_new_e = thunk_scope_arg a_io.Ast.io_input new_e in
( (fun next ->
Bindlib.box_apply2
(fun next thunked_or_nonempty_new_e ->
@ -772,18 +797,15 @@ let translate_scope_decl
scope_input_variables
(next, List.length scope_input_variables - 1))
in
let scope_input_struct_fields =
let field_map =
List.map
(fun (var_ctx, dvar) ->
let struct_field_name =
StructFieldName.fresh (Bindlib.name_of dvar ^ "_in", pos_sigma)
in
struct_field_name, input_var_typ var_ctx)
(fun (var_ctx, _) ->
let var = var_ctx.scope_var_name in
let field, _ = ScopeVarMap.find var scope_sig.scope_sig_in_fields in
field, input_var_typ var_ctx)
scope_input_variables
in
let new_struct_ctx =
StructMap.singleton scope_input_struct_name scope_input_struct_fields
in
let new_struct_ctx = StructMap.singleton scope_input_struct_name field_map in
( Bindlib.box_apply
(fun scope_body_expr ->
{
@ -819,6 +841,19 @@ let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program =
(fun s -> s ^ "_in")
(ScopeName.get_info scope_name))
in
let scope_sig_in_fields =
ScopeVarMap.filter_map
(fun dvar (_, vis) ->
match Marked.unmark vis.Ast.io_input with
| NoInput -> None
| OnlyInput | Reentrant ->
let info = ScopeVar.get_info dvar in
let s = Marked.unmark info ^ "_in" in
Some
( StructFieldName.fresh (s, Marked.get_mark info),
vis.Ast.io_input ))
scope.scope_sig
in
{
scope_sig_local_vars =
List.map
@ -833,11 +868,12 @@ let translate_program (prgm : 'm Ast.program) : 'm Dcalc.Ast.program =
scope_sig_input_var = scope_input_var;
scope_sig_input_struct = scope_input_struct_name;
scope_sig_output_struct = scope_return_struct_name;
scope_sig_in_fields;
})
prgm.program_scopes
in
(* the resulting expression is the list of definitions of all the scopes,
ending with the top-level scope. The decl_ctx is allocated in left-to-right
ending with the top-level scope. The decl_ctx is filled in left-to-right
order, then the chained scopes aggregated from the right. *)
let rec translate_scopes decl_ctx = function
| scope_name :: next_scopes ->

View File

@ -851,7 +851,17 @@ let make_tuple el structname m0 =
let m =
fold_marks
(fun posl -> List.hd posl)
(fun ml -> TTuple (List.map (fun t -> t.ty) ml), (List.hd ml).pos)
(fun ml ->
let pos = (List.hd ml).pos in
match structname with
| Some n -> TStruct n, pos
| None -> TTuple (List.map (fun t -> t.ty) ml), pos)
(List.map (fun e -> Marked.get_mark e) el)
in
etuple el structname m
let make_struct fieldmap structname m =
let fields =
List.rev (StructFieldMap.fold (fun _ e acc -> e :: acc) fieldmap [])
in
make_tuple fields (Some structname) m

View File

@ -296,6 +296,14 @@ val make_tuple :
(** Builds a tuple; the mark argument is only used as witness and for position
when building 0-uples *)
val make_struct :
(([< dcalc | lcalc ] as 'a), 'm mark) boxed_gexpr StructFieldMap.t ->
StructName.t ->
'm mark ->
('a, 'm mark) boxed_gexpr
(** Builds the tuple of values for the given struct with proper ordering,
assuming the structfieldmap contains the fields defined for structname *)
(** {2 Transformations} *)
val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr

View File

@ -0,0 +1,26 @@
```catala
declaration scope Toto:
context bar content integer
output foo content integer
scope Toto:
definition bar equals 1
definition foo equals 1212 + bar
declaration scope Titi:
output fizz content Toto
output fuzz content Toto
toto scope Toto
scope Titi:
definition toto.bar equals 44
definition fizz equals Toto of {}
definition fuzz equals Toto of {--bar: 111}
```
```catala-test-inline
$ catala Interpret -s Titi
[RESULT] Computation successful! Results:
[RESULT] fizz = Toto {"foo"= 1213}
[RESULT] fuzz = Toto {"foo"= 1323}
```