Fix handling of context vars with all call cases

This commit is contained in:
Louis Gesbert 2023-10-12 14:41:57 +02:00
parent af8ff472a5
commit 61ec34e3d9
3 changed files with 199 additions and 67 deletions

View File

@ -27,6 +27,10 @@ type scope_input_var_ctx = {
scope_input_name : StructField.t; scope_input_name : StructField.t;
scope_input_io : Runtime.io_input Mark.pos; scope_input_io : Runtime.io_input Mark.pos;
scope_input_typ : naked_typ; scope_input_typ : naked_typ;
scope_input_thunked : bool;
(* For reentrant variables: if true, the type t of the field has been
changed to (unit -> t). Otherwise, the type was already a function and
wasn't changed so no additional wrapping will be needed *)
} }
type 'm scope_ref = type 'm scope_ref =
@ -193,19 +197,30 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
in in
excepts excepts
let thunk_scope_arg ~is_func io_in e = let input_var_needs_thunking typ io_in =
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so (* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
that we can put them in default terms at the initialisation of the function that we can put them in default terms at the initialisation of the function
body, allowing an empty error to recover the default value. *) body, allowing an empty error to recover the default value. *)
let silent_var = Var.make "_" in match Mark.remove io_in.Desugared.Ast.io_input, typ with
let pos = Mark.get io_in in | Runtime.Reentrant, TArrow _ ->
match Mark.remove io_in with false (* we don't need to thunk expressions that are already functions *)
| Runtime.NoInput -> invalid_arg "thunk_scope_arg" | Runtime.Reentrant, _ -> true
| Runtime.OnlyInput -> Expr.eerroronempty e (Mark.get e) | _ -> false
| Runtime.Reentrant ->
(* we don't need to thunk expressions that are already functions *) let input_var_typ typ io_in =
if is_func then e let pos = Mark.get io_in.Desugared.Ast.io_input in
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos if input_var_needs_thunking typ io_in then
TArrow ([TLit TUnit, pos], (typ, pos)), pos
else typ, pos
let thunk_scope_arg var_ctx e =
match var_ctx.scope_input_io, var_ctx.scope_input_thunked with
| (Runtime.NoInput, _), _ -> invalid_arg "thunk_scope_arg"
| (Runtime.OnlyInput, _), false -> Expr.eerroronempty e (Mark.get e)
| (Runtime.Reentrant, _), false -> e
| (Runtime.Reentrant, pos), true ->
Expr.make_abs [| Var.make "_" |] e [TLit TUnit, pos] pos
| _ -> assert false
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
'm Ast.expr boxed = 'm Ast.expr boxed =
@ -246,23 +261,27 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
let in_var_map = let in_var_map =
ScopeVar.Map.merge ScopeVar.Map.merge
(fun var_name (str_field : scope_input_var_ctx option) expr -> (fun var_name (str_field : scope_input_var_ctx option) expr ->
let expr =
match str_field, expr with
| Some { scope_input_io = Reentrant, _; _ }, None ->
Some (Expr.unbox (Expr.eemptyerror (mark_tany m pos)))
| _ -> expr
in
match str_field, expr with match str_field, expr with
| None, None -> None | None, None -> assert false
| Some ({ scope_input_io = Reentrant, iopos; _ } as var_ctx), None ->
let ty0 =
match var_ctx.scope_input_typ with
| TArrow ([_], ty) -> ty
| _ -> assert false
(* reentrant field must be thunked with correct function type at
this point *)
in
Some
( var_ctx.scope_input_name,
Expr.make_abs
[| Var.make "_" |]
(Expr.eemptyerror (Expr.with_ty m ty0))
[TAny, iopos]
pos )
| Some var_ctx, Some e -> | Some var_ctx, Some e ->
Some Some
( var_ctx.scope_input_name, ( var_ctx.scope_input_name,
thunk_scope_arg thunk_scope_arg var_ctx (translate_expr ctx e) )
~is_func:
(match var_ctx.scope_input_typ with
| TArrow _ -> true
| _ -> false)
var_ctx.scope_input_io (translate_expr ctx e) )
| Some var_ctx, None -> | Some var_ctx, None ->
Message.raise_multispanned_error Message.raise_multispanned_error
[ [
@ -662,9 +681,14 @@ let translate_rule
}) })
[sigma_name, pos_sigma; a_name] [sigma_name, pos_sigma; a_name]
in in
let is_func = match Mark.remove tau with TArrow _ -> true | _ -> false in
let thunked_or_nonempty_new_e = let thunked_or_nonempty_new_e =
thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e match a_io.Desugared.Ast.io_input with
| Runtime.NoInput, _ -> assert false
| Runtime.OnlyInput, _ -> Expr.eerroronempty new_e (Mark.get new_e)
| Runtime.Reentrant, pos -> (
match Mark.remove tau with
| TArrow _ -> new_e
| _ -> Expr.thunk_term new_e (Expr.with_pos pos (Mark.get new_e)))
in in
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
@ -673,13 +697,7 @@ let translate_rule
{ {
scope_let_next = next; scope_let_next = next;
scope_let_pos = Mark.get a_name; scope_let_pos = Mark.get a_name;
scope_let_typ = scope_let_typ = input_var_typ (Mark.remove tau) a_io;
(match Mark.remove a_io.io_input with
| NoInput -> failwith "should not happen"
| OnlyInput -> tau
| Reentrant ->
if is_func then tau
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
scope_let_expr = thunked_or_nonempty_new_e; scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
}) })
@ -922,17 +940,6 @@ let translate_rules
(Expr.Box.lift return_exp)), (Expr.Box.lift return_exp)),
new_ctx ) new_ctx )
let input_var_typ typ io_in =
match io_in.Desugared.Ast.io_input with
| Runtime.OnlyInput, pos -> typ, pos
| Runtime.Reentrant, pos -> (
match typ with
| TArrow _ -> typ, pos
| _ ->
( TArrow ([TLit TUnit, pos], (typ, pos)),
pos ))
| Runtime.NoInput, _ -> invalid_arg "input_var_typ"
(* From a scope declaration and definitions, create the corresponding scope body (* From a scope declaration and definitions, create the corresponding scope body
wrapped in the appropriate call convention. *) wrapped in the appropriate call convention. *)
let translate_scope_decl let translate_scope_decl
@ -1032,7 +1039,8 @@ let translate_scope_decl
scope_let_kind = DestructuringInputStruct; scope_let_kind = DestructuringInputStruct;
scope_let_next = next; scope_let_next = next;
scope_let_pos = pos_sigma; scope_let_pos = pos_sigma;
scope_let_typ = input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io; scope_let_typ =
input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
scope_let_expr = scope_let_expr =
( EStructAccess ( EStructAccess
{ name = scope_input_struct_name; e = r; field }, { name = scope_input_struct_name; e = r; field },
@ -1045,11 +1053,11 @@ let translate_scope_decl
in in
Bindlib.box_apply Bindlib.box_apply
(fun scope_body_expr -> (fun scope_body_expr ->
{ {
scope_body_expr; scope_body_expr;
scope_body_input_struct = scope_input_struct_name; scope_body_input_struct = scope_input_struct_name;
scope_body_output_struct = scope_return_struct_name; scope_body_output_struct = scope_return_struct_name;
}) })
(Bindlib.bind_var scope_input_var (Bindlib.bind_var scope_input_var
(input_destructurings rules_with_return_expr)) (input_destructurings rules_with_return_expr))
@ -1097,7 +1105,10 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
{ {
scope_input_name = StructField.fresh (s, Mark.get info); scope_input_name = StructField.fresh (s, Mark.get info);
scope_input_io = vis.Desugared.Ast.io_input; scope_input_io = vis.Desugared.Ast.io_input;
scope_input_typ = Mark.remove (input_var_typ (Mark.remove typ) vis); scope_input_typ =
Mark.remove (input_var_typ (Mark.remove typ) vis);
scope_input_thunked =
input_var_needs_thunking (Mark.remove typ) vis;
}) })
scope.Scopelang.Ast.scope_sig scope.Scopelang.Ast.scope_sig
in in
@ -1141,18 +1152,16 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
let add_scope_in_structs scope_sigs structs = let add_scope_in_structs scope_sigs structs =
ScopeName.Map.fold ScopeName.Map.fold
(fun _ scope_sig_ctx acc -> (fun _ scope_sig_ctx acc ->
let fields = let fields =
ScopeVar.Map.fold ScopeVar.Map.fold
(fun _ sivc acc -> (fun _ sivc acc ->
let pos = let pos = Mark.get (StructField.get_info sivc.scope_input_name) in
Mark.get (StructField.get_info sivc.scope_input_name) StructField.Map.add sivc.scope_input_name
in (sivc.scope_input_typ, pos)
StructField.Map.add sivc.scope_input_name acc)
(sivc.scope_input_typ, pos) scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
acc) in
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
in
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
scope_sigs.scope_sigs structs scope_sigs.scope_sigs structs
in in
let rec gather_module_in_structs acc sctx = let rec gather_module_in_structs acc sctx =
@ -1171,7 +1180,6 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules); (gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
} }
in in
Message.emit_debug "STRUCTS: %a" (StructName.Map.format_keys ~pp_sep:Format.pp_print_space) decl_ctx.ctx_structs;
let top_ctx = let top_ctx =
let toplevel_vars = let toplevel_vars =
TopdefName.Map.mapi TopdefName.Map.mapi

View File

@ -425,7 +425,7 @@ let rec evaluate_operator
(* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and* (* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and*
of the OCaml runtime *) of the OCaml runtime *)
let rec runtime_to_val : let rec runtime_to_val :
type d e. type d e.
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) -> (decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
decl_ctx -> decl_ctx ->
'm mark -> 'm mark ->
@ -481,7 +481,7 @@ let rec runtime_to_val :
| TAny -> assert false | TAny -> assert false
and val_to_runtime : and val_to_runtime :
type d e . type d e.
(decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) -> (decl_ctx -> ((d, e, _) astk, 'm) gexpr -> ((d, e, _) astk, 'm) gexpr) ->
decl_ctx -> decl_ctx ->
typ -> typ ->
@ -542,8 +542,7 @@ and val_to_runtime :
curry [] targs curry [] targs
| _ -> | _ ->
Message.raise_internal_error Message.raise_internal_error
"Could not convert value of type %a to runtime: %a" "Could not convert value of type %a to runtime: %a" (Print.typ ctx) ty
(Print.typ ctx) ty
Expr.format v Expr.format v
let rec evaluate_expr : let rec evaluate_expr :

View File

@ -0,0 +1,125 @@
Testing adequacy of the scope calling convention with various types of
parameters (reentrant, functions ...) ; and different calls (through subscopes
or direct scope calls). The main part of the test is in `mod_use_context`.
> Module Mod_def_context
```catala-metadata
declaration scope S:
context output ci content integer
context output cm content money
context output cfun1 content decimal depends on x content integer
input output cfun2 content decimal depends on x content integer
```
```catala
scope S:
definition ci equals 0
definition cm equals $0
definition cfun1 of x equals x / 2
```
Now testing direct calls within the same module
```catala
declaration third content decimal
depends on x content integer
equals x / 3
declaration quarter content decimal
depends on x content integer
equals x / 4
```
```catala
declaration scope Stest:
output o1 content S
output o2 content S
output x11 content decimal
output x12 content decimal
output x21 content decimal
output x22 content decimal
scope Stest:
definition o1 equals
output of S with { -- cfun2: quarter }
definition o2 equals
output of S with {
-- ci: 1
-- cm: $1
-- cfun1: third
-- cfun2: quarter
}
definition x11 equals o1.cfun1 of 24
definition x12 equals o1.cfun2 of 24
definition x21 equals o2.cfun1 of 24
definition x22 equals o2.cfun2 of 24
```
```catala-test-inline
$ catala interpret -s Stest
[RESULT] Computation successful! Results:
[RESULT]
o1 = S { -- ci: 0 -- cm: $0.00 -- cfun1: <function> -- cfun2: <function> }
[RESULT]
o2 = S { -- ci: 1 -- cm: $1.00 -- cfun1: <function> -- cfun2: <function> }
[RESULT] x11 = 12.0
[RESULT] x12 = 6.0
[RESULT] x21 = 8.0
[RESULT] x22 = 6.0
```
### Testing subscopes (with and without context override)
```catala
declaration scope TestSubDefault:
sub scope S
output ci content integer
output cm content money
output x11 content decimal
output x12 content decimal
scope TestSubDefault:
definition sub.cfun2 of x equals quarter of x
definition ci equals sub.ci
definition cm equals sub.cm
definition x11 equals sub.cfun1 of 24
definition x12 equals sub.cfun2 of 24
```
```catala-test-inline
$ catala interpret -s TestSubDefault
[RESULT] Computation successful! Results:
[RESULT] ci = 0
[RESULT] cm = $0.00
[RESULT] x11 = 12.0
[RESULT] x12 = 6.0
```
```catala
declaration scope TestSubOverride:
sub scope S
output ci content integer
output cm content money
output x21 content decimal
output x22 content decimal
scope TestSubOverride:
definition sub.ci equals 1
definition sub.cm equals $1
definition sub.cfun1 of x equals third of x
definition sub.cfun2 of x equals quarter of x
definition ci equals sub.ci
definition cm equals sub.cm
definition x21 equals sub.cfun1 of 24
definition x22 equals sub.cfun2 of 24
```
```catala-test-inline
$ catala interpret -s TestSubOverride
[RESULT] Computation successful! Results:
[RESULT] ci = 1
[RESULT] cm = $1.00
[RESULT] x21 = 8.0
[RESULT] x22 = 6.0
```