mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Fix handling of context vars with all call cases
This commit is contained in:
parent
af8ff472a5
commit
61ec34e3d9
@ -27,6 +27,10 @@ type scope_input_var_ctx = {
|
|||||||
scope_input_name : StructField.t;
|
scope_input_name : StructField.t;
|
||||||
scope_input_io : Runtime.io_input Mark.pos;
|
scope_input_io : Runtime.io_input Mark.pos;
|
||||||
scope_input_typ : naked_typ;
|
scope_input_typ : naked_typ;
|
||||||
|
scope_input_thunked : bool;
|
||||||
|
(* For reentrant variables: if true, the type t of the field has been
|
||||||
|
changed to (unit -> t). Otherwise, the type was already a function and
|
||||||
|
wasn't changed so no additional wrapping will be needed *)
|
||||||
}
|
}
|
||||||
|
|
||||||
type 'm scope_ref =
|
type 'm scope_ref =
|
||||||
@ -193,19 +197,30 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
|
|||||||
in
|
in
|
||||||
excepts
|
excepts
|
||||||
|
|
||||||
let thunk_scope_arg ~is_func io_in e =
|
let input_var_needs_thunking typ io_in =
|
||||||
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
|
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
|
||||||
that we can put them in default terms at the initialisation of the function
|
that we can put them in default terms at the initialisation of the function
|
||||||
body, allowing an empty error to recover the default value. *)
|
body, allowing an empty error to recover the default value. *)
|
||||||
let silent_var = Var.make "_" in
|
match Mark.remove io_in.Desugared.Ast.io_input, typ with
|
||||||
let pos = Mark.get io_in in
|
| Runtime.Reentrant, TArrow _ ->
|
||||||
match Mark.remove io_in with
|
false (* we don't need to thunk expressions that are already functions *)
|
||||||
| Runtime.NoInput -> invalid_arg "thunk_scope_arg"
|
| Runtime.Reentrant, _ -> true
|
||||||
| Runtime.OnlyInput -> Expr.eerroronempty e (Mark.get e)
|
| _ -> false
|
||||||
| Runtime.Reentrant ->
|
|
||||||
(* we don't need to thunk expressions that are already functions *)
|
let input_var_typ typ io_in =
|
||||||
if is_func then e
|
let pos = Mark.get io_in.Desugared.Ast.io_input in
|
||||||
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
|
if input_var_needs_thunking typ io_in then
|
||||||
|
TArrow ([TLit TUnit, pos], (typ, pos)), pos
|
||||||
|
else typ, pos
|
||||||
|
|
||||||
|
let thunk_scope_arg var_ctx e =
|
||||||
|
match var_ctx.scope_input_io, var_ctx.scope_input_thunked with
|
||||||
|
| (Runtime.NoInput, _), _ -> invalid_arg "thunk_scope_arg"
|
||||||
|
| (Runtime.OnlyInput, _), false -> Expr.eerroronempty e (Mark.get e)
|
||||||
|
| (Runtime.Reentrant, _), false -> e
|
||||||
|
| (Runtime.Reentrant, pos), true ->
|
||||||
|
Expr.make_abs [| Var.make "_" |] e [TLit TUnit, pos] pos
|
||||||
|
| _ -> assert false
|
||||||
|
|
||||||
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||||
'm Ast.expr boxed =
|
'm Ast.expr boxed =
|
||||||
@ -246,23 +261,27 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
|||||||
let in_var_map =
|
let in_var_map =
|
||||||
ScopeVar.Map.merge
|
ScopeVar.Map.merge
|
||||||
(fun var_name (str_field : scope_input_var_ctx option) expr ->
|
(fun var_name (str_field : scope_input_var_ctx option) expr ->
|
||||||
let expr =
|
|
||||||
match str_field, expr with
|
|
||||||
| Some { scope_input_io = Reentrant, _; _ }, None ->
|
|
||||||
Some (Expr.unbox (Expr.eemptyerror (mark_tany m pos)))
|
|
||||||
| _ -> expr
|
|
||||||
in
|
|
||||||
match str_field, expr with
|
match str_field, expr with
|
||||||
| None, None -> None
|
| None, None -> assert false
|
||||||
|
| Some ({ scope_input_io = Reentrant, iopos; _ } as var_ctx), None ->
|
||||||
|
let ty0 =
|
||||||
|
match var_ctx.scope_input_typ with
|
||||||
|
| TArrow ([_], ty) -> ty
|
||||||
|
| _ -> assert false
|
||||||
|
(* reentrant field must be thunked with correct function type at
|
||||||
|
this point *)
|
||||||
|
in
|
||||||
|
Some
|
||||||
|
( var_ctx.scope_input_name,
|
||||||
|
Expr.make_abs
|
||||||
|
[| Var.make "_" |]
|
||||||
|
(Expr.eemptyerror (Expr.with_ty m ty0))
|
||||||
|
[TAny, iopos]
|
||||||
|
pos )
|
||||||
| Some var_ctx, Some e ->
|
| Some var_ctx, Some e ->
|
||||||
Some
|
Some
|
||||||
( var_ctx.scope_input_name,
|
( var_ctx.scope_input_name,
|
||||||
thunk_scope_arg
|
thunk_scope_arg var_ctx (translate_expr ctx e) )
|
||||||
~is_func:
|
|
||||||
(match var_ctx.scope_input_typ with
|
|
||||||
| TArrow _ -> true
|
|
||||||
| _ -> false)
|
|
||||||
var_ctx.scope_input_io (translate_expr ctx e) )
|
|
||||||
| Some var_ctx, None ->
|
| Some var_ctx, None ->
|
||||||
Message.raise_multispanned_error
|
Message.raise_multispanned_error
|
||||||
[
|
[
|
||||||
@ -662,9 +681,14 @@ let translate_rule
|
|||||||
})
|
})
|
||||||
[sigma_name, pos_sigma; a_name]
|
[sigma_name, pos_sigma; a_name]
|
||||||
in
|
in
|
||||||
let is_func = match Mark.remove tau with TArrow _ -> true | _ -> false in
|
|
||||||
let thunked_or_nonempty_new_e =
|
let thunked_or_nonempty_new_e =
|
||||||
thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e
|
match a_io.Desugared.Ast.io_input with
|
||||||
|
| Runtime.NoInput, _ -> assert false
|
||||||
|
| Runtime.OnlyInput, _ -> Expr.eerroronempty new_e (Mark.get new_e)
|
||||||
|
| Runtime.Reentrant, pos -> (
|
||||||
|
match Mark.remove tau with
|
||||||
|
| TArrow _ -> new_e
|
||||||
|
| _ -> Expr.thunk_term new_e (Expr.with_pos pos (Mark.get new_e)))
|
||||||
in
|
in
|
||||||
( (fun next ->
|
( (fun next ->
|
||||||
Bindlib.box_apply2
|
Bindlib.box_apply2
|
||||||
@ -673,13 +697,7 @@ let translate_rule
|
|||||||
{
|
{
|
||||||
scope_let_next = next;
|
scope_let_next = next;
|
||||||
scope_let_pos = Mark.get a_name;
|
scope_let_pos = Mark.get a_name;
|
||||||
scope_let_typ =
|
scope_let_typ = input_var_typ (Mark.remove tau) a_io;
|
||||||
(match Mark.remove a_io.io_input with
|
|
||||||
| NoInput -> failwith "should not happen"
|
|
||||||
| OnlyInput -> tau
|
|
||||||
| Reentrant ->
|
|
||||||
if is_func then tau
|
|
||||||
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
|
|
||||||
scope_let_expr = thunked_or_nonempty_new_e;
|
scope_let_expr = thunked_or_nonempty_new_e;
|
||||||
scope_let_kind = SubScopeVarDefinition;
|
scope_let_kind = SubScopeVarDefinition;
|
||||||
})
|
})
|
||||||
@ -922,17 +940,6 @@ let translate_rules
|
|||||||
(Expr.Box.lift return_exp)),
|
(Expr.Box.lift return_exp)),
|
||||||
new_ctx )
|
new_ctx )
|
||||||
|
|
||||||
let input_var_typ typ io_in =
|
|
||||||
match io_in.Desugared.Ast.io_input with
|
|
||||||
| Runtime.OnlyInput, pos -> typ, pos
|
|
||||||
| Runtime.Reentrant, pos -> (
|
|
||||||
match typ with
|
|
||||||
| TArrow _ -> typ, pos
|
|
||||||
| _ ->
|
|
||||||
( TArrow ([TLit TUnit, pos], (typ, pos)),
|
|
||||||
pos ))
|
|
||||||
| Runtime.NoInput, _ -> invalid_arg "input_var_typ"
|
|
||||||
|
|
||||||
(* From a scope declaration and definitions, create the corresponding scope body
|
(* From a scope declaration and definitions, create the corresponding scope body
|
||||||
wrapped in the appropriate call convention. *)
|
wrapped in the appropriate call convention. *)
|
||||||
let translate_scope_decl
|
let translate_scope_decl
|
||||||
@ -1032,7 +1039,8 @@ let translate_scope_decl
|
|||||||
scope_let_kind = DestructuringInputStruct;
|
scope_let_kind = DestructuringInputStruct;
|
||||||
scope_let_next = next;
|
scope_let_next = next;
|
||||||
scope_let_pos = pos_sigma;
|
scope_let_pos = pos_sigma;
|
||||||
scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
|
scope_let_typ =
|
||||||
|
input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
|
||||||
scope_let_expr =
|
scope_let_expr =
|
||||||
( EStructAccess
|
( EStructAccess
|
||||||
{ name = scope_input_struct_name; e = r; field },
|
{ name = scope_input_struct_name; e = r; field },
|
||||||
@ -1045,11 +1053,11 @@ let translate_scope_decl
|
|||||||
in
|
in
|
||||||
Bindlib.box_apply
|
Bindlib.box_apply
|
||||||
(fun scope_body_expr ->
|
(fun scope_body_expr ->
|
||||||
{
|
{
|
||||||
scope_body_expr;
|
scope_body_expr;
|
||||||
scope_body_input_struct = scope_input_struct_name;
|
scope_body_input_struct = scope_input_struct_name;
|
||||||
scope_body_output_struct = scope_return_struct_name;
|
scope_body_output_struct = scope_return_struct_name;
|
||||||
})
|
})
|
||||||
(Bindlib.bind_var scope_input_var
|
(Bindlib.bind_var scope_input_var
|
||||||
(input_destructurings rules_with_return_expr))
|
(input_destructurings rules_with_return_expr))
|
||||||
|
|
||||||
@ -1097,7 +1105,10 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
|||||||
{
|
{
|
||||||
scope_input_name = StructField.fresh (s, Mark.get info);
|
scope_input_name = StructField.fresh (s, Mark.get info);
|
||||||
scope_input_io = vis.Desugared.Ast.io_input;
|
scope_input_io = vis.Desugared.Ast.io_input;
|
||||||
scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis);
|
scope_input_typ =
|
||||||
|
Mark.remove (input_var_typ (Mark.remove typ) vis);
|
||||||
|
scope_input_thunked =
|
||||||
|
input_var_needs_thunking (Mark.remove typ) vis;
|
||||||
})
|
})
|
||||||
scope.Scopelang.Ast.scope_sig
|
scope.Scopelang.Ast.scope_sig
|
||||||
in
|
in
|
||||||
@ -1141,18 +1152,16 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
|||||||
let add_scope_in_structs scope_sigs structs =
|
let add_scope_in_structs scope_sigs structs =
|
||||||
ScopeName.Map.fold
|
ScopeName.Map.fold
|
||||||
(fun _ scope_sig_ctx acc ->
|
(fun _ scope_sig_ctx acc ->
|
||||||
let fields =
|
let fields =
|
||||||
ScopeVar.Map.fold
|
ScopeVar.Map.fold
|
||||||
(fun _ sivc acc ->
|
(fun _ sivc acc ->
|
||||||
let pos =
|
let pos = Mark.get (StructField.get_info sivc.scope_input_name) in
|
||||||
Mark.get (StructField.get_info sivc.scope_input_name)
|
StructField.Map.add sivc.scope_input_name
|
||||||
in
|
(sivc.scope_input_typ, pos)
|
||||||
StructField.Map.add sivc.scope_input_name
|
acc)
|
||||||
(sivc.scope_input_typ, pos)
|
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
|
||||||
acc)
|
in
|
||||||
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
|
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
|
||||||
in
|
|
||||||
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
|
|
||||||
scope_sigs.scope_sigs structs
|
scope_sigs.scope_sigs structs
|
||||||
in
|
in
|
||||||
let rec gather_module_in_structs acc sctx =
|
let rec gather_module_in_structs acc sctx =
|
||||||
@ -1171,7 +1180,6 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
|||||||
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
|
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
|
||||||
}
|
}
|
||||||
in
|
in
|
||||||
Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs;
|
|
||||||
let top_ctx =
|
let top_ctx =
|
||||||
let toplevel_vars =
|
let toplevel_vars =
|
||||||
TopdefName.Map.mapi
|
TopdefName.Map.mapi
|
||||||
|
@ -425,7 +425,7 @@ let rec evaluate_operator
|
|||||||
(* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and*
|
(* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and*
|
||||||
of the OCaml runtime *)
|
of the OCaml runtime *)
|
||||||
let rec runtime_to_val :
|
let rec runtime_to_val :
|
||||||
type d e.
|
type d e.
|
||||||
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
|
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
|
||||||
decl_ctx ->
|
decl_ctx ->
|
||||||
'm mark ->
|
'm mark ->
|
||||||
@ -481,7 +481,7 @@ let rec runtime_to_val :
|
|||||||
| TAny -> assert false
|
| TAny -> assert false
|
||||||
|
|
||||||
and val_to_runtime :
|
and val_to_runtime :
|
||||||
type d e .
|
type d e.
|
||||||
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
|
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
|
||||||
decl_ctx ->
|
decl_ctx ->
|
||||||
typ ->
|
typ ->
|
||||||
@ -542,8 +542,7 @@ and val_to_runtime :
|
|||||||
curry [] targs
|
curry [] targs
|
||||||
| _ ->
|
| _ ->
|
||||||
Message.raise_internal_error
|
Message.raise_internal_error
|
||||||
"Could not convert value of type %a to runtime: %a"
|
"Could not convert value of type %a to runtime: %a" (Print.typ ctx) ty
|
||||||
(Print.typ ctx) ty
|
|
||||||
Expr.format v
|
Expr.format v
|
||||||
|
|
||||||
let rec evaluate_expr :
|
let rec evaluate_expr :
|
||||||
|
125
tests/test_modules/good/mod_def_context.catala_en
Normal file
125
tests/test_modules/good/mod_def_context.catala_en
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
Testing adequacy of the scope calling convention with various types of
|
||||||
|
parameters (reentrant, functions ...) ; and different calls (through subscopes
|
||||||
|
or direct scope calls). The main part of the test is in `mod_use_context`.
|
||||||
|
|
||||||
|
> Module Mod_def_context
|
||||||
|
|
||||||
|
```catala-metadata
|
||||||
|
declaration scope S:
|
||||||
|
context output ci content integer
|
||||||
|
context output cm content money
|
||||||
|
context output cfun1 content decimal depends on x content integer
|
||||||
|
input output cfun2 content decimal depends on x content integer
|
||||||
|
```
|
||||||
|
|
||||||
|
```catala
|
||||||
|
scope S:
|
||||||
|
definition ci equals 0
|
||||||
|
definition cm equals $0
|
||||||
|
definition cfun1 of x equals x / 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Now testing direct calls within the same module
|
||||||
|
|
||||||
|
```catala
|
||||||
|
declaration third content decimal
|
||||||
|
depends on x content integer
|
||||||
|
equals x / 3
|
||||||
|
|
||||||
|
declaration quarter content decimal
|
||||||
|
depends on x content integer
|
||||||
|
equals x / 4
|
||||||
|
```
|
||||||
|
|
||||||
|
```catala
|
||||||
|
declaration scope Stest:
|
||||||
|
output o1 content S
|
||||||
|
output o2 content S
|
||||||
|
output x11 content decimal
|
||||||
|
output x12 content decimal
|
||||||
|
output x21 content decimal
|
||||||
|
output x22 content decimal
|
||||||
|
|
||||||
|
scope Stest:
|
||||||
|
definition o1 equals
|
||||||
|
output of S with { -- cfun2: quarter }
|
||||||
|
definition o2 equals
|
||||||
|
output of S with {
|
||||||
|
-- ci: 1
|
||||||
|
-- cm: $1
|
||||||
|
-- cfun1: third
|
||||||
|
-- cfun2: quarter
|
||||||
|
}
|
||||||
|
definition x11 equals o1.cfun1 of 24
|
||||||
|
definition x12 equals o1.cfun2 of 24
|
||||||
|
definition x21 equals o2.cfun1 of 24
|
||||||
|
definition x22 equals o2.cfun2 of 24
|
||||||
|
```
|
||||||
|
|
||||||
|
```catala-test-inline
|
||||||
|
$ catala interpret -s Stest
|
||||||
|
[RESULT] Computation successful! Results:
|
||||||
|
[RESULT]
|
||||||
|
o1 = S { -- ci: 0 -- cm: $0.00 -- cfun1: <function> -- cfun2: <function> }
|
||||||
|
[RESULT]
|
||||||
|
o2 = S { -- ci: 1 -- cm: $1.00 -- cfun1: <function> -- cfun2: <function> }
|
||||||
|
[RESULT] x11 = 12.0
|
||||||
|
[RESULT] x12 = 6.0
|
||||||
|
[RESULT] x21 = 8.0
|
||||||
|
[RESULT] x22 = 6.0
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing subscopes (with and without context override)
|
||||||
|
|
||||||
|
```catala
|
||||||
|
declaration scope TestSubDefault:
|
||||||
|
sub scope S
|
||||||
|
output ci content integer
|
||||||
|
output cm content money
|
||||||
|
output x11 content decimal
|
||||||
|
output x12 content decimal
|
||||||
|
|
||||||
|
scope TestSubDefault:
|
||||||
|
definition sub.cfun2 of x equals quarter of x
|
||||||
|
definition ci equals sub.ci
|
||||||
|
definition cm equals sub.cm
|
||||||
|
definition x11 equals sub.cfun1 of 24
|
||||||
|
definition x12 equals sub.cfun2 of 24
|
||||||
|
```
|
||||||
|
|
||||||
|
```catala-test-inline
|
||||||
|
$ catala interpret -s TestSubDefault
|
||||||
|
[RESULT] Computation successful! Results:
|
||||||
|
[RESULT] ci = 0
|
||||||
|
[RESULT] cm = $0.00
|
||||||
|
[RESULT] x11 = 12.0
|
||||||
|
[RESULT] x12 = 6.0
|
||||||
|
```
|
||||||
|
|
||||||
|
```catala
|
||||||
|
declaration scope TestSubOverride:
|
||||||
|
sub scope S
|
||||||
|
output ci content integer
|
||||||
|
output cm content money
|
||||||
|
output x21 content decimal
|
||||||
|
output x22 content decimal
|
||||||
|
|
||||||
|
scope TestSubOverride:
|
||||||
|
definition sub.ci equals 1
|
||||||
|
definition sub.cm equals $1
|
||||||
|
definition sub.cfun1 of x equals third of x
|
||||||
|
definition sub.cfun2 of x equals quarter of x
|
||||||
|
definition ci equals sub.ci
|
||||||
|
definition cm equals sub.cm
|
||||||
|
definition x21 equals sub.cfun1 of 24
|
||||||
|
definition x22 equals sub.cfun2 of 24
|
||||||
|
```
|
||||||
|
|
||||||
|
```catala-test-inline
|
||||||
|
$ catala interpret -s TestSubOverride
|
||||||
|
[RESULT] Computation successful! Results:
|
||||||
|
[RESULT] ci = 1
|
||||||
|
[RESULT] cm = $1.00
|
||||||
|
[RESULT] x21 = 8.0
|
||||||
|
[RESULT] x22 = 6.0
|
||||||
|
```
|
Loading…
Reference in New Issue
Block a user