Fix handling of context vars with all call cases

This commit is contained in:
Louis Gesbert 2023-10-12 14:41:57 +02:00
parent af8ff472a5
commit 61ec34e3d9
3 changed files with 199 additions and 67 deletions

View File

@ -27,6 +27,10 @@ type scope_input_var_ctx = {
scope_input_name : StructField.t;
scope_input_io : Runtime.io_input Mark.pos;
scope_input_typ : naked_typ;
scope_input_thunked : bool;
(* For reentrant variables: if true, the type t of the field has been
changed to (unit -> t). Otherwise, the type was already a function and
wasn't changed so no additional wrapping will be needed *)
}
type 'm scope_ref =
@ -193,19 +197,30 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
in
excepts
let thunk_scope_arg ~is_func io_in e =
let input_var_needs_thunking typ io_in =
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
that we can put them in default terms at the initialisation of the function
body, allowing an empty error to recover the default value. *)
let silent_var = Var.make "_" in
let pos = Mark.get io_in in
match Mark.remove io_in with
| Runtime.NoInput -> invalid_arg "thunk_scope_arg"
| Runtime.OnlyInput -> Expr.eerroronempty e (Mark.get e)
| Runtime.Reentrant ->
(* we don't need to thunk expressions that are already functions *)
if is_func then e
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
match Mark.remove io_in.Desugared.Ast.io_input, typ with
| Runtime.Reentrant, TArrow _ ->
false (* we don't need to thunk expressions that are already functions *)
| Runtime.Reentrant, _ -> true
| _ -> false
let input_var_typ typ io_in =
let pos = Mark.get io_in.Desugared.Ast.io_input in
if input_var_needs_thunking typ io_in then
TArrow ([TLit TUnit, pos], (typ, pos)), pos
else typ, pos
let thunk_scope_arg var_ctx e =
match var_ctx.scope_input_io, var_ctx.scope_input_thunked with
| (Runtime.NoInput, _), _ -> invalid_arg "thunk_scope_arg"
| (Runtime.OnlyInput, _), false -> Expr.eerroronempty e (Mark.get e)
| (Runtime.Reentrant, _), false -> e
| (Runtime.Reentrant, pos), true ->
Expr.make_abs [| Var.make "_" |] e [TLit TUnit, pos] pos
| _ -> assert false
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
'm Ast.expr boxed =
@ -246,23 +261,27 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
let in_var_map =
ScopeVar.Map.merge
(fun var_name (str_field : scope_input_var_ctx option) expr ->
let expr =
match str_field, expr with
| Some { scope_input_io = Reentrant, _; _ }, None ->
Some (Expr.unbox (Expr.eemptyerror (mark_tany m pos)))
| _ -> expr
in
match str_field, expr with
| None, None -> None
| None, None -> assert false
| Some ({ scope_input_io = Reentrant, iopos; _ } as var_ctx), None ->
let ty0 =
match var_ctx.scope_input_typ with
| TArrow ([_], ty) -> ty
| _ -> assert false
(* reentrant field must be thunked with correct function type at
this point *)
in
Some
( var_ctx.scope_input_name,
Expr.make_abs
[| Var.make "_" |]
(Expr.eemptyerror (Expr.with_ty m ty0))
[TAny, iopos]
pos )
| Some var_ctx, Some e ->
Some
( var_ctx.scope_input_name,
thunk_scope_arg
~is_func:
(match var_ctx.scope_input_typ with
| TArrow _ -> true
| _ -> false)
var_ctx.scope_input_io (translate_expr ctx e) )
thunk_scope_arg var_ctx (translate_expr ctx e) )
| Some var_ctx, None ->
Message.raise_multispanned_error
[
@ -662,9 +681,14 @@ let translate_rule
})
[sigma_name, pos_sigma; a_name]
in
let is_func = match Mark.remove tau with TArrow _ -> true | _ -> false in
let thunked_or_nonempty_new_e =
thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e
match a_io.Desugared.Ast.io_input with
| Runtime.NoInput, _ -> assert false
| Runtime.OnlyInput, _ -> Expr.eerroronempty new_e (Mark.get new_e)
| Runtime.Reentrant, pos -> (
match Mark.remove tau with
| TArrow _ -> new_e
| _ -> Expr.thunk_term new_e (Expr.with_pos pos (Mark.get new_e)))
in
( (fun next ->
Bindlib.box_apply2
@ -673,13 +697,7 @@ let translate_rule
{
scope_let_next = next;
scope_let_pos = Mark.get a_name;
scope_let_typ =
(match Mark.remove a_io.io_input with
| NoInput -> failwith "should not happen"
| OnlyInput -> tau
| Reentrant ->
if is_func then tau
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
scope_let_typ = input_var_typ (Mark.remove tau) a_io;
scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition;
})
@ -922,17 +940,6 @@ let translate_rules
(Expr.Box.lift return_exp)),
new_ctx )
let input_var_typ typ io_in =
match io_in.Desugared.Ast.io_input with
| Runtime.OnlyInput, pos -> typ, pos
| Runtime.Reentrant, pos -> (
match typ with
| TArrow _ -> typ, pos
| _ ->
( TArrow ([TLit TUnit, pos], (typ, pos)),
pos ))
| Runtime.NoInput, _ -> invalid_arg "input_var_typ"
(* From a scope declaration and definitions, create the corresponding scope body
wrapped in the appropriate call convention. *)
let translate_scope_decl
@ -1032,7 +1039,8 @@ let translate_scope_decl
scope_let_kind = DestructuringInputStruct;
scope_let_next = next;
scope_let_pos = pos_sigma;
scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
scope_let_typ =
input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
scope_let_expr =
( EStructAccess
{ name = scope_input_struct_name; e = r; field },
@ -1045,11 +1053,11 @@ let translate_scope_decl
in
Bindlib.box_apply
(fun scope_body_expr ->
{
scope_body_expr;
scope_body_input_struct = scope_input_struct_name;
scope_body_output_struct = scope_return_struct_name;
})
{
scope_body_expr;
scope_body_input_struct = scope_input_struct_name;
scope_body_output_struct = scope_return_struct_name;
})
(Bindlib.bind_var scope_input_var
(input_destructurings rules_with_return_expr))
@ -1097,7 +1105,10 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
{
scope_input_name = StructField.fresh (s, Mark.get info);
scope_input_io = vis.Desugared.Ast.io_input;
scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis);
scope_input_typ =
Mark.remove (input_var_typ (Mark.remove typ) vis);
scope_input_thunked =
input_var_needs_thunking (Mark.remove typ) vis;
})
scope.Scopelang.Ast.scope_sig
in
@ -1141,18 +1152,16 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
let add_scope_in_structs scope_sigs structs =
ScopeName.Map.fold
(fun _ scope_sig_ctx acc ->
let fields =
ScopeVar.Map.fold
(fun _ sivc acc ->
let pos =
Mark.get (StructField.get_info sivc.scope_input_name)
in
StructField.Map.add sivc.scope_input_name
(sivc.scope_input_typ, pos)
acc)
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
in
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
let fields =
ScopeVar.Map.fold
(fun _ sivc acc ->
let pos = Mark.get (StructField.get_info sivc.scope_input_name) in
StructField.Map.add sivc.scope_input_name
(sivc.scope_input_typ, pos)
acc)
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
in
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
scope_sigs.scope_sigs structs
in
let rec gather_module_in_structs acc sctx =
@ -1171,7 +1180,6 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
}
in
Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs;
let top_ctx =
let toplevel_vars =
TopdefName.Map.mapi

View File

@ -425,7 +425,7 @@ let rec evaluate_operator
(* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and*
of the OCaml runtime *)
let rec runtime_to_val :
type d e.
type d e.
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
decl_ctx ->
'm mark ->
@ -481,7 +481,7 @@ let rec runtime_to_val :
| TAny -> assert false
and val_to_runtime :
type d e .
type d e.
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
decl_ctx ->
typ ->
@ -542,8 +542,7 @@ and val_to_runtime :
curry [] targs
| _ ->
Message.raise_internal_error
"Could not convert value of type %a to runtime: %a"
(Print.typ ctx) ty
"Could not convert value of type %a to runtime: %a" (Print.typ ctx) ty
Expr.format v
let rec evaluate_expr :

View File

@ -0,0 +1,125 @@
Testing adequacy of the scope calling convention with various types of
parameters (reentrant, functions ...) ; and different calls (through subscopes
or direct scope calls). The main part of the test is in `mod_use_context`.
> Module Mod_def_context
```catala-metadata
declaration scope S:
context output ci content integer
context output cm content money
context output cfun1 content decimal depends on x content integer
input output cfun2 content decimal depends on x content integer
```
```catala
scope S:
definition ci equals 0
definition cm equals $0
definition cfun1 of x equals x / 2
```
Now testing direct calls within the same module
```catala
declaration third content decimal
depends on x content integer
equals x / 3
declaration quarter content decimal
depends on x content integer
equals x / 4
```
```catala
declaration scope Stest:
output o1 content S
output o2 content S
output x11 content decimal
output x12 content decimal
output x21 content decimal
output x22 content decimal
scope Stest:
definition o1 equals
output of S with { -- cfun2: quarter }
definition o2 equals
output of S with {
-- ci: 1
-- cm: $1
-- cfun1: third
-- cfun2: quarter
}
definition x11 equals o1.cfun1 of 24
definition x12 equals o1.cfun2 of 24
definition x21 equals o2.cfun1 of 24
definition x22 equals o2.cfun2 of 24
```
```catala-test-inline
$ catala interpret -s Stest
[RESULT] Computation successful! Results:
[RESULT]
o1 = S { -- ci: 0 -- cm: $0.00 -- cfun1: <function> -- cfun2: <function> }
[RESULT]
o2 = S { -- ci: 1 -- cm: $1.00 -- cfun1: <function> -- cfun2: <function> }
[RESULT] x11 = 12.0
[RESULT] x12 = 6.0
[RESULT] x21 = 8.0
[RESULT] x22 = 6.0
```
### Testing subscopes (with and without context override)
```catala
declaration scope TestSubDefault:
sub scope S
output ci content integer
output cm content money
output x11 content decimal
output x12 content decimal
scope TestSubDefault:
definition sub.cfun2 of x equals quarter of x
definition ci equals sub.ci
definition cm equals sub.cm
definition x11 equals sub.cfun1 of 24
definition x12 equals sub.cfun2 of 24
```
```catala-test-inline
$ catala interpret -s TestSubDefault
[RESULT] Computation successful! Results:
[RESULT] ci = 0
[RESULT] cm = $0.00
[RESULT] x11 = 12.0
[RESULT] x12 = 6.0
```
```catala
declaration scope TestSubOverride:
sub scope S
output ci content integer
output cm content money
output x21 content decimal
output x22 content decimal
scope TestSubOverride:
definition sub.ci equals 1
definition sub.cm equals $1
definition sub.cfun1 of x equals third of x
definition sub.cfun2 of x equals quarter of x
definition ci equals sub.ci
definition cm equals sub.cm
definition x21 equals sub.cfun1 of 24
definition x22 equals sub.cfun2 of 24
```
```catala-test-inline
$ catala interpret -s TestSubOverride
[RESULT] Computation successful! Results:
[RESULT] ci = 1
[RESULT] cm = $1.00
[RESULT] x21 = 8.0
[RESULT] x22 = 6.0
```