Fix issue 362 with wrong encoding of context variables that are functions (#363)

This commit is contained in:
Denis Merigoux 2022-12-07 19:22:01 +01:00 committed by GitHub
commit 90291c55c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 203 additions and 57 deletions

View File

@ -23,6 +23,12 @@ type scope_var_ctx = {
scope_var_io : Desugared.Ast.io;
}
type scope_input_var_ctx = {
scope_input_name : StructField.t;
scope_input_io : Desugared.Ast.io_input Marked.pos;
scope_input_typ : naked_typ;
}
type 'm scope_sig_ctx = {
scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *)
scope_sig_scope_var : 'm Ast.expr Var.t; (** Var representing the scope *)
@ -30,8 +36,7 @@ type 'm scope_sig_ctx = {
(** Var representing the scope input inside the scope func *)
scope_sig_input_struct : StructName.t; (** Scope input *)
scope_sig_output_struct : StructName.t; (** Scope output *)
scope_sig_in_fields :
(StructField.t * Desugared.Ast.io_input Marked.pos) ScopeVar.Map.t;
scope_sig_in_fields : scope_input_var_ctx ScopeVar.Map.t;
(** Mapping between the input scope variables and the input struct fields. *)
scope_sig_out_fields : StructField.t ScopeVar.Map.t;
(** Mapping between the output scope variables and the output struct
@ -81,23 +86,62 @@ let pos_mark_mk (type a m) (e : (a, m mark) gexpr) :
let pos_mark_as e = pos_mark (Marked.get_mark e) in
pos_mark, pos_mark_as
let merge_defaults caller callee =
let caller =
let m = Marked.get_mark caller in
let pos = Expr.mark_pos m in
Expr.make_app caller
[Expr.elit LUnit (Expr.with_ty m (Marked.mark pos (TLit TUnit)))]
pos
in
let body =
let m = Marked.get_mark callee in
let ltrue =
Expr.elit (LBool true)
(Expr.with_ty m (Marked.mark (Expr.mark_pos m) (TLit TBool)))
let merge_defaults
~(is_func : bool)
(caller : (dcalc, 'm mark) boxed_gexpr)
(callee : (dcalc, 'm mark) boxed_gexpr) : (dcalc, 'm mark) boxed_gexpr =
(* the merging of the two defaults, from the reentrant caller and the callee,
is straightfoward in the general case and a little subtler when the
variable being defined is a function. *)
if is_func then
let m_callee = Marked.get_mark callee in
let unboxed_callee = Expr.unbox callee in
match Marked.unmark unboxed_callee with
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let m_body = Marked.get_mark body in
let caller =
let m = Marked.get_mark caller in
let pos = Expr.mark_pos m in
Expr.make_app caller
(List.map2
(fun (var : (dcalc, 'm mark) naked_gexpr Bindlib.var) ty ->
Expr.evar var
(* we have to correctly propagate types when doing this
rewriting *)
(Expr.with_ty m_body ~pos:(Expr.mark_pos m_body) ty))
(Array.to_list vars) tys)
pos
in
let ltrue =
Expr.elit (LBool true)
(Expr.with_ty m_callee
(Marked.mark (Expr.mark_pos m_callee) (TLit TBool)))
in
let d = Expr.edefault [caller] ltrue (Expr.rebox body) m_body in
Expr.make_abs vars
(Expr.eerroronempty d m_body)
tys (Expr.mark_pos m_callee)
| _ -> assert false
(* should not happen because there should always be a lambda at the
beginning of a default with a function type *)
else
let caller =
let m = Marked.get_mark caller in
let pos = Expr.mark_pos m in
Expr.make_app caller
[Expr.elit LUnit (Expr.with_ty m (Marked.mark pos (TLit TUnit)))]
pos
in
Expr.edefault [caller] ltrue callee m
in
body
let body =
let m = Marked.get_mark callee in
let ltrue =
Expr.elit (LBool true)
(Expr.with_ty m (Marked.mark (Expr.mark_pos m) (TLit TBool)))
in
Expr.eerroronempty (Expr.edefault [caller] ltrue callee m) m
in
body
let tag_with_log_entry
(e : 'm Ast.expr boxed)
@ -150,7 +194,7 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
in
excepts
let thunk_scope_arg io_in e =
let thunk_scope_arg ~is_func io_in e =
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
that we can put them in default terms at the initialisation of the function
body, allowing an empty error to recover the default value. *)
@ -160,7 +204,9 @@ let thunk_scope_arg io_in e =
| Desugared.Ast.NoInput -> invalid_arg "thunk_scope_arg"
| Desugared.Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e)
| Desugared.Ast.Reentrant ->
Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
(* we don't need to thunk expressions that are already functions *)
if is_func then e
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
'm Ast.expr boxed =
@ -214,23 +260,31 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
let sc_sig = ScopeName.Map.find scope ctx.scopes_parameters in
let in_var_map =
ScopeVar.Map.merge
(fun var_name str_field expr ->
(fun var_name (str_field : scope_input_var_ctx option) expr ->
let expr =
match str_field, expr with
| Some (_, (Desugared.Ast.Reentrant, _)), None ->
| Some { scope_input_io = Desugared.Ast.Reentrant, _; _ }, None ->
Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos)))
| _ -> expr
in
match str_field, expr with
| None, None -> None
| Some (fld, io_in), Some e ->
Some (fld, thunk_scope_arg io_in (translate_expr ctx e))
| Some (fld, _), None ->
| Some var_ctx, Some e ->
Some
( var_ctx.scope_input_name,
thunk_scope_arg
~is_func:
(match var_ctx.scope_input_typ with
| TArrow _ -> true
| _ -> false)
var_ctx.scope_input_io (translate_expr ctx e) )
| Some var_ctx, None ->
Errors.raise_multispanned_error
[
None, pos;
( Some "Declaration of the missing input variable",
Marked.get_mark (StructField.get_info fld) );
Marked.get_mark
(StructField.get_info var_ctx.scope_input_name) );
]
"Definition of input variable '%a' missing in this scope call"
ScopeVar.format_t var_name
@ -387,14 +441,15 @@ let translate_rule
let new_e = translate_expr ctx e in
let a_expr = Expr.make_var a_var (pos_mark var_def_pos) in
let merged_expr =
Expr.eerroronempty
(match Marked.unmark a_io.io_input with
| OnlyInput -> failwith "should not happen"
(* scopelang should not contain any definitions of input only
variables *)
| Reentrant -> merge_defaults a_expr new_e
| NoInput -> new_e)
(pos_mark_as a_name)
match Marked.unmark a_io.io_input with
| OnlyInput -> failwith "should not happen"
(* scopelang should not contain any definitions of input only variables *)
| Reentrant ->
merge_defaults
~is_func:
(match Marked.unmark tau with TArrow _ -> true | _ -> false)
a_expr new_e
| NoInput -> Expr.eerroronempty new_e (pos_mark_as a_name)
in
let merged_expr =
tag_with_log_entry merged_expr
@ -438,8 +493,11 @@ let translate_rule
(VarDef (Marked.unmark tau))
[sigma_name, pos_sigma; a_name]
in
let is_func =
match Marked.unmark tau with TArrow _ -> true | _ -> false
in
let thunked_or_nonempty_new_e =
thunk_scope_arg a_io.Desugared.Ast.io_input new_e
thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e
in
( (fun next ->
Bindlib.box_apply2
@ -453,7 +511,8 @@ let translate_rule
| NoInput -> failwith "should not happen"
| OnlyInput -> tau
| Reentrant ->
TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
if is_func then tau
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition;
})
@ -521,9 +580,9 @@ let translate_rule
Expr.make_var a_var (mark_tany m pos_call)
in
let field =
Marked.unmark
(ScopeVar.Map.find subvar.scope_var_name
subscope_sig.scope_sig_in_fields)
(ScopeVar.Map.find subvar.scope_var_name
subscope_sig.scope_sig_in_fields)
.scope_input_name
in
StructField.Map.add field e acc)
StructField.Map.empty all_subscope_input_vars
@ -739,18 +798,21 @@ let translate_scope_decl
let input_var_typ (var_ctx : scope_var_ctx) =
match Marked.unmark var_ctx.scope_var_io.io_input with
| OnlyInput -> var_ctx.scope_var_typ, pos_sigma
| Reentrant ->
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma )
| Reentrant -> (
match var_ctx.scope_var_typ with
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
| _ ->
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma ))
| NoInput -> failwith "should not happen"
in
let input_destructurings next =
List.fold_right
(fun (var_ctx, v) next ->
let field =
Marked.unmark
(ScopeVar.Map.find var_ctx.scope_var_name
scope_sig.scope_sig_in_fields)
(ScopeVar.Map.find var_ctx.scope_var_name
scope_sig.scope_sig_in_fields)
.scope_input_name
in
Bindlib.box_apply2
(fun next r ->
@ -775,7 +837,9 @@ let translate_scope_decl
List.fold_left
(fun acc (var_ctx, _) ->
let var = var_ctx.scope_var_name in
let field, _ = ScopeVar.Map.find var scope_sig.scope_sig_in_fields in
let field =
(ScopeVar.Map.find var scope_sig.scope_sig_in_fields).scope_input_name
in
StructField.Map.add field (input_var_typ var_ctx) acc)
StructField.Map.empty scope_input_variables
in
@ -820,15 +884,19 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
in
let scope_sig_in_fields =
ScopeVar.Map.filter_map
(fun dvar (_, vis) ->
(fun dvar (typ, vis) ->
match Marked.unmark vis.Desugared.Ast.io_input with
| NoInput -> None
| OnlyInput | Reentrant ->
let info = ScopeVar.get_info dvar in
let s = Marked.unmark info ^ "_in" in
Some
( StructField.fresh (s, Marked.get_mark info),
vis.Desugared.Ast.io_input ))
{
scope_input_name =
StructField.fresh (s, Marked.get_mark info);
scope_input_io = vis.Desugared.Ast.io_input;
scope_input_typ = Marked.unmark typ;
})
scope.scope_sig
in
{

View File

@ -504,8 +504,11 @@ let interpret_program :
StructField.Map.map
(fun ty ->
match Marked.unmark ty with
| TArrow ((TLit TUnit, _), ty_in) ->
Expr.empty_thunked_term (Expr.with_ty mark_e ty_in)
| TArrow (ty_in, ty_out) ->
Expr.make_abs
[| Var.make "_" |]
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
[ty_in] (Expr.mark_pos mark_e)
| _ ->
Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \

View File

@ -191,6 +191,7 @@ let def_map_to_tree
when to place the toplevel binding in the case of functions. *)
let rec rule_tree_to_expr
~(toplevel : bool)
~(is_reentrant_var : bool)
(ctx : ctx)
(def_pos : Pos.t)
(is_func : Desugared.Ast.expr Var.t option)
@ -271,7 +272,9 @@ let rec rule_tree_to_expr
emark
in
let exceptions =
List.map (rule_tree_to_expr ~toplevel:false ctx def_pos is_func) exceptions
List.map
(rule_tree_to_expr ~toplevel:false ~is_reentrant_var ctx def_pos is_func)
exceptions
in
let default =
Expr.make_default exceptions
@ -283,8 +286,13 @@ let rec rule_tree_to_expr
| Some new_param, Some (_, typ) ->
if toplevel then
(* When we're creating a function from multiple defaults, we must check
that the result returned by the function is not empty *)
let default = Expr.eerroronempty default emark in
that the result returned by the function is not empty, unless we're
dealing with a context variable which is reentrant (either in the
caller or callee). In this case the ErrorOnEmpty will be added later in
the scopelang->dcalc translation. *)
let default =
if is_reentrant_var then default else Expr.eerroronempty default emark
in
Expr.make_abs
[| Var.Map.find new_param ctx.var_mapping |]
default [typ] def_pos
@ -348,6 +356,11 @@ let translate_def
| OnlyInput -> true
| _ -> false
in
let is_reentrant =
match Marked.unmark io.Desugared.Ast.io_input with
| Reentrant -> true
| _ -> false
in
let top_value =
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
(* We add the bottom [false] value for conditions, only for the scope
@ -385,10 +398,14 @@ let translate_def
will not be provided by the calee scope, it has to be placed in the
caller. *)
then
Expr.elit LEmptyError
(Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info })
let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in
let empty_error = Expr.elit LEmptyError m in
match is_def_func_param_typ with
| Some ty ->
Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m)
| _ -> empty_error
else
rule_tree_to_expr ~toplevel:true ctx
rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx
(Desugared.Ast.ScopeDef.get_position def_info)
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
(match top_list, top_value with

View File

@ -82,6 +82,18 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : typed expr) :
(* context sub-scope variables *)
let _, body = Bindlib.unmbind binder in
body
| EAbs { binder; _ } -> (
(* context scope variables *)
let _, body = Bindlib.unmbind binder in
match Marked.unmark body with
| EErrorOnEmpty e -> e
| _ ->
Errors.raise_spanned_error (Expr.pos e)
"Internal error: this expression does not have the structure expected \
by the VC generator:\n\
%a"
(Expr.format ~debug:true ctx.decl)
e)
| EErrorOnEmpty d ->
d (* input subscope variables and non-input scope variable *)
| _ ->

View File

@ -0,0 +1,46 @@
## Test
```catala
declaration scope A:
context f content integer depends on integer
declaration scope B:
input b content boolean
a scope A
scope A:
definition f of x equals x + 1
scope B:
definition a.f of x under condition b and x > 0 consequence equals x - 1
```
```catala-test-inline
$ catala Scopelang -s B
let scope B (b: bool|input) =
let a.f : integer → integer =
λ (param: integer) → ⟨b && param > 0 ⊢ param - 1⟩;
call A[a]
```
```catala-test-inline
$ catala Dcalc -s A
let A =
λ (A_in: A_in {"f_in": integer → integer}) →
let f : integer → integer = A_in."f_in" in
let f1 : integer → integer =
λ (param: integer) → error_empty
⟨f param | true ⊢ ⟨true ⊢ param + 1⟩⟩ in
A { }
```
```catala-test-inline
$ catala Dcalc -s B
let B =
λ (B_in: B_in {"b_in": bool}) →
let b : bool = B_in."b_in" in
let a.f : integer → integer =
λ (param: integer) → ⟨b && param > 0 ⊢ param - 1⟩ in
let result : A {} = A (A_in { "f_in"= a.f }) in
B { }
```