mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Fix handling of context vars with all call cases
This commit is contained in:
parent
af8ff472a5
commit
61ec34e3d9
@ -27,6 +27,10 @@ type scope_input_var_ctx = {
|
||||
scope_input_name : StructField.t;
|
||||
scope_input_io : Runtime.io_input Mark.pos;
|
||||
scope_input_typ : naked_typ;
|
||||
scope_input_thunked : bool;
|
||||
(* For reentrant variables: if true, the type t of the field has been
|
||||
changed to (unit -> t). Otherwise, the type was already a function and
|
||||
wasn't changed so no additional wrapping will be needed *)
|
||||
}
|
||||
|
||||
type 'm scope_ref =
|
||||
@ -193,19 +197,30 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
|
||||
in
|
||||
excepts
|
||||
|
||||
let thunk_scope_arg ~is_func io_in e =
|
||||
let input_var_needs_thunking typ io_in =
|
||||
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
|
||||
that we can put them in default terms at the initialisation of the function
|
||||
body, allowing an empty error to recover the default value. *)
|
||||
let silent_var = Var.make "_" in
|
||||
let pos = Mark.get io_in in
|
||||
match Mark.remove io_in with
|
||||
| Runtime.NoInput -> invalid_arg "thunk_scope_arg"
|
||||
| Runtime.OnlyInput -> Expr.eerroronempty e (Mark.get e)
|
||||
| Runtime.Reentrant ->
|
||||
(* we don't need to thunk expressions that are already functions *)
|
||||
if is_func then e
|
||||
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
|
||||
match Mark.remove io_in.Desugared.Ast.io_input, typ with
|
||||
| Runtime.Reentrant, TArrow _ ->
|
||||
false (* we don't need to thunk expressions that are already functions *)
|
||||
| Runtime.Reentrant, _ -> true
|
||||
| _ -> false
|
||||
|
||||
let input_var_typ typ io_in =
|
||||
let pos = Mark.get io_in.Desugared.Ast.io_input in
|
||||
if input_var_needs_thunking typ io_in then
|
||||
TArrow ([TLit TUnit, pos], (typ, pos)), pos
|
||||
else typ, pos
|
||||
|
||||
let thunk_scope_arg var_ctx e =
|
||||
match var_ctx.scope_input_io, var_ctx.scope_input_thunked with
|
||||
| (Runtime.NoInput, _), _ -> invalid_arg "thunk_scope_arg"
|
||||
| (Runtime.OnlyInput, _), false -> Expr.eerroronempty e (Mark.get e)
|
||||
| (Runtime.Reentrant, _), false -> e
|
||||
| (Runtime.Reentrant, pos), true ->
|
||||
Expr.make_abs [| Var.make "_" |] e [TLit TUnit, pos] pos
|
||||
| _ -> assert false
|
||||
|
||||
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
'm Ast.expr boxed =
|
||||
@ -246,23 +261,27 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
let in_var_map =
|
||||
ScopeVar.Map.merge
|
||||
(fun var_name (str_field : scope_input_var_ctx option) expr ->
|
||||
let expr =
|
||||
match str_field, expr with
|
||||
| Some { scope_input_io = Reentrant, _; _ }, None ->
|
||||
Some (Expr.unbox (Expr.eemptyerror (mark_tany m pos)))
|
||||
| _ -> expr
|
||||
| None, None -> assert false
|
||||
| Some ({ scope_input_io = Reentrant, iopos; _ } as var_ctx), None ->
|
||||
let ty0 =
|
||||
match var_ctx.scope_input_typ with
|
||||
| TArrow ([_], ty) -> ty
|
||||
| _ -> assert false
|
||||
(* reentrant field must be thunked with correct function type at
|
||||
this point *)
|
||||
in
|
||||
match str_field, expr with
|
||||
| None, None -> None
|
||||
Some
|
||||
( var_ctx.scope_input_name,
|
||||
Expr.make_abs
|
||||
[| Var.make "_" |]
|
||||
(Expr.eemptyerror (Expr.with_ty m ty0))
|
||||
[TAny, iopos]
|
||||
pos )
|
||||
| Some var_ctx, Some e ->
|
||||
Some
|
||||
( var_ctx.scope_input_name,
|
||||
thunk_scope_arg
|
||||
~is_func:
|
||||
(match var_ctx.scope_input_typ with
|
||||
| TArrow _ -> true
|
||||
| _ -> false)
|
||||
var_ctx.scope_input_io (translate_expr ctx e) )
|
||||
thunk_scope_arg var_ctx (translate_expr ctx e) )
|
||||
| Some var_ctx, None ->
|
||||
Message.raise_multispanned_error
|
||||
[
|
||||
@ -662,9 +681,14 @@ let translate_rule
|
||||
})
|
||||
[sigma_name, pos_sigma; a_name]
|
||||
in
|
||||
let is_func = match Mark.remove tau with TArrow _ -> true | _ -> false in
|
||||
let thunked_or_nonempty_new_e =
|
||||
thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e
|
||||
match a_io.Desugared.Ast.io_input with
|
||||
| Runtime.NoInput, _ -> assert false
|
||||
| Runtime.OnlyInput, _ -> Expr.eerroronempty new_e (Mark.get new_e)
|
||||
| Runtime.Reentrant, pos -> (
|
||||
match Mark.remove tau with
|
||||
| TArrow _ -> new_e
|
||||
| _ -> Expr.thunk_term new_e (Expr.with_pos pos (Mark.get new_e)))
|
||||
in
|
||||
( (fun next ->
|
||||
Bindlib.box_apply2
|
||||
@ -673,13 +697,7 @@ let translate_rule
|
||||
{
|
||||
scope_let_next = next;
|
||||
scope_let_pos = Mark.get a_name;
|
||||
scope_let_typ =
|
||||
(match Mark.remove a_io.io_input with
|
||||
| NoInput -> failwith "should not happen"
|
||||
| OnlyInput -> tau
|
||||
| Reentrant ->
|
||||
if is_func then tau
|
||||
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
|
||||
scope_let_typ = input_var_typ (Mark.remove tau) a_io;
|
||||
scope_let_expr = thunked_or_nonempty_new_e;
|
||||
scope_let_kind = SubScopeVarDefinition;
|
||||
})
|
||||
@ -922,17 +940,6 @@ let translate_rules
|
||||
(Expr.Box.lift return_exp)),
|
||||
new_ctx )
|
||||
|
||||
let input_var_typ typ io_in =
|
||||
match io_in.Desugared.Ast.io_input with
|
||||
| Runtime.OnlyInput, pos -> typ, pos
|
||||
| Runtime.Reentrant, pos -> (
|
||||
match typ with
|
||||
| TArrow _ -> typ, pos
|
||||
| _ ->
|
||||
( TArrow ([TLit TUnit, pos], (typ, pos)),
|
||||
pos ))
|
||||
| Runtime.NoInput, _ -> invalid_arg "input_var_typ"
|
||||
|
||||
(* From a scope declaration and definitions, create the corresponding scope body
|
||||
wrapped in the appropriate call convention. *)
|
||||
let translate_scope_decl
|
||||
@ -1032,7 +1039,8 @@ let translate_scope_decl
|
||||
scope_let_kind = DestructuringInputStruct;
|
||||
scope_let_next = next;
|
||||
scope_let_pos = pos_sigma;
|
||||
scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
|
||||
scope_let_typ =
|
||||
input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
|
||||
scope_let_expr =
|
||||
( EStructAccess
|
||||
{ name = scope_input_struct_name; e = r; field },
|
||||
@ -1097,7 +1105,10 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
{
|
||||
scope_input_name = StructField.fresh (s, Mark.get info);
|
||||
scope_input_io = vis.Desugared.Ast.io_input;
|
||||
scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis);
|
||||
scope_input_typ =
|
||||
Mark.remove (input_var_typ (Mark.remove typ) vis);
|
||||
scope_input_thunked =
|
||||
input_var_needs_thunking (Mark.remove typ) vis;
|
||||
})
|
||||
scope.Scopelang.Ast.scope_sig
|
||||
in
|
||||
@ -1144,9 +1155,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
let fields =
|
||||
ScopeVar.Map.fold
|
||||
(fun _ sivc acc ->
|
||||
let pos =
|
||||
Mark.get (StructField.get_info sivc.scope_input_name)
|
||||
in
|
||||
let pos = Mark.get (StructField.get_info sivc.scope_input_name) in
|
||||
StructField.Map.add sivc.scope_input_name
|
||||
(sivc.scope_input_typ, pos)
|
||||
acc)
|
||||
@ -1171,7 +1180,6 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
|
||||
}
|
||||
in
|
||||
Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs;
|
||||
let top_ctx =
|
||||
let toplevel_vars =
|
||||
TopdefName.Map.mapi
|
||||
|
@ -481,7 +481,7 @@ let rec runtime_to_val :
|
||||
| TAny -> assert false
|
||||
|
||||
and val_to_runtime :
|
||||
type d e .
|
||||
type d e.
|
||||
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
|
||||
decl_ctx ->
|
||||
typ ->
|
||||
@ -542,8 +542,7 @@ and val_to_runtime :
|
||||
curry [] targs
|
||||
| _ ->
|
||||
Message.raise_internal_error
|
||||
"Could not convert value of type %a to runtime: %a"
|
||||
(Print.typ ctx) ty
|
||||
"Could not convert value of type %a to runtime: %a" (Print.typ ctx) ty
|
||||
Expr.format v
|
||||
|
||||
let rec evaluate_expr :
|
||||
|
125
tests/test_modules/good/mod_def_context.catala_en
Normal file
125
tests/test_modules/good/mod_def_context.catala_en
Normal file
@ -0,0 +1,125 @@
|
||||
Testing adequacy of the scope calling convention with various types of
|
||||
parameters (reentrant, functions ...) ; and different calls (through subscopes
|
||||
or direct scope calls). The main part of the test is in `mod_use_context`.
|
||||
|
||||
> Module Mod_def_context
|
||||
|
||||
```catala-metadata
|
||||
declaration scope S:
|
||||
context output ci content integer
|
||||
context output cm content money
|
||||
context output cfun1 content decimal depends on x content integer
|
||||
input output cfun2 content decimal depends on x content integer
|
||||
```
|
||||
|
||||
```catala
|
||||
scope S:
|
||||
definition ci equals 0
|
||||
definition cm equals $0
|
||||
definition cfun1 of x equals x / 2
|
||||
```
|
||||
|
||||
Now testing direct calls within the same module
|
||||
|
||||
```catala
|
||||
declaration third content decimal
|
||||
depends on x content integer
|
||||
equals x / 3
|
||||
|
||||
declaration quarter content decimal
|
||||
depends on x content integer
|
||||
equals x / 4
|
||||
```
|
||||
|
||||
```catala
|
||||
declaration scope Stest:
|
||||
output o1 content S
|
||||
output o2 content S
|
||||
output x11 content decimal
|
||||
output x12 content decimal
|
||||
output x21 content decimal
|
||||
output x22 content decimal
|
||||
|
||||
scope Stest:
|
||||
definition o1 equals
|
||||
output of S with { -- cfun2: quarter }
|
||||
definition o2 equals
|
||||
output of S with {
|
||||
-- ci: 1
|
||||
-- cm: $1
|
||||
-- cfun1: third
|
||||
-- cfun2: quarter
|
||||
}
|
||||
definition x11 equals o1.cfun1 of 24
|
||||
definition x12 equals o1.cfun2 of 24
|
||||
definition x21 equals o2.cfun1 of 24
|
||||
definition x22 equals o2.cfun2 of 24
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala interpret -s Stest
|
||||
[RESULT] Computation successful! Results:
|
||||
[RESULT]
|
||||
o1 = S { -- ci: 0 -- cm: $0.00 -- cfun1: <function> -- cfun2: <function> }
|
||||
[RESULT]
|
||||
o2 = S { -- ci: 1 -- cm: $1.00 -- cfun1: <function> -- cfun2: <function> }
|
||||
[RESULT] x11 = 12.0
|
||||
[RESULT] x12 = 6.0
|
||||
[RESULT] x21 = 8.0
|
||||
[RESULT] x22 = 6.0
|
||||
```
|
||||
|
||||
### Testing subscopes (with and without context override)
|
||||
|
||||
```catala
|
||||
declaration scope TestSubDefault:
|
||||
sub scope S
|
||||
output ci content integer
|
||||
output cm content money
|
||||
output x11 content decimal
|
||||
output x12 content decimal
|
||||
|
||||
scope TestSubDefault:
|
||||
definition sub.cfun2 of x equals quarter of x
|
||||
definition ci equals sub.ci
|
||||
definition cm equals sub.cm
|
||||
definition x11 equals sub.cfun1 of 24
|
||||
definition x12 equals sub.cfun2 of 24
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala interpret -s TestSubDefault
|
||||
[RESULT] Computation successful! Results:
|
||||
[RESULT] ci = 0
|
||||
[RESULT] cm = $0.00
|
||||
[RESULT] x11 = 12.0
|
||||
[RESULT] x12 = 6.0
|
||||
```
|
||||
|
||||
```catala
|
||||
declaration scope TestSubOverride:
|
||||
sub scope S
|
||||
output ci content integer
|
||||
output cm content money
|
||||
output x21 content decimal
|
||||
output x22 content decimal
|
||||
|
||||
scope TestSubOverride:
|
||||
definition sub.ci equals 1
|
||||
definition sub.cm equals $1
|
||||
definition sub.cfun1 of x equals third of x
|
||||
definition sub.cfun2 of x equals quarter of x
|
||||
definition ci equals sub.ci
|
||||
definition cm equals sub.cm
|
||||
definition x21 equals sub.cfun1 of 24
|
||||
definition x22 equals sub.cfun2 of 24
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala interpret -s TestSubOverride
|
||||
[RESULT] Computation successful! Results:
|
||||
[RESULT] ci = 1
|
||||
[RESULT] cm = $1.00
|
||||
[RESULT] x21 = 8.0
|
||||
[RESULT] x22 = 6.0
|
||||
```
|
Loading…
Reference in New Issue
Block a user