Fix issue 362 with wrong encoding of context variables that are functions (#363)

This commit is contained in:
Denis Merigoux 2022-12-07 19:22:01 +01:00 committed by GitHub
commit 90291c55c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 203 additions and 57 deletions

View File

@ -23,6 +23,12 @@ type scope_var_ctx = {
scope_var_io : Desugared.Ast.io; scope_var_io : Desugared.Ast.io;
} }
type scope_input_var_ctx = {
scope_input_name : StructField.t;
scope_input_io : Desugared.Ast.io_input Marked.pos;
scope_input_typ : naked_typ;
}
type 'm scope_sig_ctx = { type 'm scope_sig_ctx = {
scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *) scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *)
scope_sig_scope_var : 'm Ast.expr Var.t; (** Var representing the scope *) scope_sig_scope_var : 'm Ast.expr Var.t; (** Var representing the scope *)
@ -30,8 +36,7 @@ type 'm scope_sig_ctx = {
(** Var representing the scope input inside the scope func *) (** Var representing the scope input inside the scope func *)
scope_sig_input_struct : StructName.t; (** Scope input *) scope_sig_input_struct : StructName.t; (** Scope input *)
scope_sig_output_struct : StructName.t; (** Scope output *) scope_sig_output_struct : StructName.t; (** Scope output *)
scope_sig_in_fields : scope_sig_in_fields : scope_input_var_ctx ScopeVar.Map.t;
(StructField.t * Desugared.Ast.io_input Marked.pos) ScopeVar.Map.t;
(** Mapping between the input scope variables and the input struct fields. *) (** Mapping between the input scope variables and the input struct fields. *)
scope_sig_out_fields : StructField.t ScopeVar.Map.t; scope_sig_out_fields : StructField.t ScopeVar.Map.t;
(** Mapping between the output scope variables and the output struct (** Mapping between the output scope variables and the output struct
@ -81,7 +86,46 @@ let pos_mark_mk (type a m) (e : (a, m mark) gexpr) :
let pos_mark_as e = pos_mark (Marked.get_mark e) in let pos_mark_as e = pos_mark (Marked.get_mark e) in
pos_mark, pos_mark_as pos_mark, pos_mark_as
let merge_defaults caller callee = let merge_defaults
~(is_func : bool)
(caller : (dcalc, 'm mark) boxed_gexpr)
(callee : (dcalc, 'm mark) boxed_gexpr) : (dcalc, 'm mark) boxed_gexpr =
(* the merging of the two defaults, from the reentrant caller and the callee,
is straightfoward in the general case and a little subtler when the
variable being defined is a function. *)
if is_func then
let m_callee = Marked.get_mark callee in
let unboxed_callee = Expr.unbox callee in
match Marked.unmark unboxed_callee with
| EAbs { binder; tys } ->
let vars, body = Bindlib.unmbind binder in
let m_body = Marked.get_mark body in
let caller =
let m = Marked.get_mark caller in
let pos = Expr.mark_pos m in
Expr.make_app caller
(List.map2
(fun (var : (dcalc, 'm mark) naked_gexpr Bindlib.var) ty ->
Expr.evar var
(* we have to correctly propagate types when doing this
rewriting *)
(Expr.with_ty m_body ~pos:(Expr.mark_pos m_body) ty))
(Array.to_list vars) tys)
pos
in
let ltrue =
Expr.elit (LBool true)
(Expr.with_ty m_callee
(Marked.mark (Expr.mark_pos m_callee) (TLit TBool)))
in
let d = Expr.edefault [caller] ltrue (Expr.rebox body) m_body in
Expr.make_abs vars
(Expr.eerroronempty d m_body)
tys (Expr.mark_pos m_callee)
| _ -> assert false
(* should not happen because there should always be a lambda at the
beginning of a default with a function type *)
else
let caller = let caller =
let m = Marked.get_mark caller in let m = Marked.get_mark caller in
let pos = Expr.mark_pos m in let pos = Expr.mark_pos m in
@ -95,7 +139,7 @@ let merge_defaults caller callee =
Expr.elit (LBool true) Expr.elit (LBool true)
(Expr.with_ty m (Marked.mark (Expr.mark_pos m) (TLit TBool))) (Expr.with_ty m (Marked.mark (Expr.mark_pos m) (TLit TBool)))
in in
Expr.edefault [caller] ltrue callee m Expr.eerroronempty (Expr.edefault [caller] ltrue callee m) m
in in
body body
@ -150,7 +194,7 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
in in
excepts excepts
let thunk_scope_arg io_in e = let thunk_scope_arg ~is_func io_in e =
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so (* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
that we can put them in default terms at the initialisation of the function that we can put them in default terms at the initialisation of the function
body, allowing an empty error to recover the default value. *) body, allowing an empty error to recover the default value. *)
@ -160,7 +204,9 @@ let thunk_scope_arg io_in e =
| Desugared.Ast.NoInput -> invalid_arg "thunk_scope_arg" | Desugared.Ast.NoInput -> invalid_arg "thunk_scope_arg"
| Desugared.Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e) | Desugared.Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e)
| Desugared.Ast.Reentrant -> | Desugared.Ast.Reentrant ->
Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos (* we don't need to thunk expressions that are already functions *)
if is_func then e
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
'm Ast.expr boxed = 'm Ast.expr boxed =
@ -214,23 +260,31 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
let sc_sig = ScopeName.Map.find scope ctx.scopes_parameters in let sc_sig = ScopeName.Map.find scope ctx.scopes_parameters in
let in_var_map = let in_var_map =
ScopeVar.Map.merge ScopeVar.Map.merge
(fun var_name str_field expr -> (fun var_name (str_field : scope_input_var_ctx option) expr ->
let expr = let expr =
match str_field, expr with match str_field, expr with
| Some (_, (Desugared.Ast.Reentrant, _)), None -> | Some { scope_input_io = Desugared.Ast.Reentrant, _; _ }, None ->
Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos))) Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos)))
| _ -> expr | _ -> expr
in in
match str_field, expr with match str_field, expr with
| None, None -> None | None, None -> None
| Some (fld, io_in), Some e -> | Some var_ctx, Some e ->
Some (fld, thunk_scope_arg io_in (translate_expr ctx e)) Some
| Some (fld, _), None -> ( var_ctx.scope_input_name,
thunk_scope_arg
~is_func:
(match var_ctx.scope_input_typ with
| TArrow _ -> true
| _ -> false)
var_ctx.scope_input_io (translate_expr ctx e) )
| Some var_ctx, None ->
Errors.raise_multispanned_error Errors.raise_multispanned_error
[ [
None, pos; None, pos;
( Some "Declaration of the missing input variable", ( Some "Declaration of the missing input variable",
Marked.get_mark (StructField.get_info fld) ); Marked.get_mark
(StructField.get_info var_ctx.scope_input_name) );
] ]
"Definition of input variable '%a' missing in this scope call" "Definition of input variable '%a' missing in this scope call"
ScopeVar.format_t var_name ScopeVar.format_t var_name
@ -387,14 +441,15 @@ let translate_rule
let new_e = translate_expr ctx e in let new_e = translate_expr ctx e in
let a_expr = Expr.make_var a_var (pos_mark var_def_pos) in let a_expr = Expr.make_var a_var (pos_mark var_def_pos) in
let merged_expr = let merged_expr =
Expr.eerroronempty match Marked.unmark a_io.io_input with
(match Marked.unmark a_io.io_input with
| OnlyInput -> failwith "should not happen" | OnlyInput -> failwith "should not happen"
(* scopelang should not contain any definitions of input only (* scopelang should not contain any definitions of input only variables *)
variables *) | Reentrant ->
| Reentrant -> merge_defaults a_expr new_e merge_defaults
| NoInput -> new_e) ~is_func:
(pos_mark_as a_name) (match Marked.unmark tau with TArrow _ -> true | _ -> false)
a_expr new_e
| NoInput -> Expr.eerroronempty new_e (pos_mark_as a_name)
in in
let merged_expr = let merged_expr =
tag_with_log_entry merged_expr tag_with_log_entry merged_expr
@ -438,8 +493,11 @@ let translate_rule
(VarDef (Marked.unmark tau)) (VarDef (Marked.unmark tau))
[sigma_name, pos_sigma; a_name] [sigma_name, pos_sigma; a_name]
in in
let is_func =
match Marked.unmark tau with TArrow _ -> true | _ -> false
in
let thunked_or_nonempty_new_e = let thunked_or_nonempty_new_e =
thunk_scope_arg a_io.Desugared.Ast.io_input new_e thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e
in in
( (fun next -> ( (fun next ->
Bindlib.box_apply2 Bindlib.box_apply2
@ -453,7 +511,8 @@ let translate_rule
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
| OnlyInput -> tau | OnlyInput -> tau
| Reentrant -> | Reentrant ->
TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos); if is_func then tau
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
scope_let_expr = thunked_or_nonempty_new_e; scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition; scope_let_kind = SubScopeVarDefinition;
}) })
@ -521,9 +580,9 @@ let translate_rule
Expr.make_var a_var (mark_tany m pos_call) Expr.make_var a_var (mark_tany m pos_call)
in in
let field = let field =
Marked.unmark
(ScopeVar.Map.find subvar.scope_var_name (ScopeVar.Map.find subvar.scope_var_name
subscope_sig.scope_sig_in_fields) subscope_sig.scope_sig_in_fields)
.scope_input_name
in in
StructField.Map.add field e acc) StructField.Map.add field e acc)
StructField.Map.empty all_subscope_input_vars StructField.Map.empty all_subscope_input_vars
@ -739,18 +798,21 @@ let translate_scope_decl
let input_var_typ (var_ctx : scope_var_ctx) = let input_var_typ (var_ctx : scope_var_ctx) =
match Marked.unmark var_ctx.scope_var_io.io_input with match Marked.unmark var_ctx.scope_var_io.io_input with
| OnlyInput -> var_ctx.scope_var_typ, pos_sigma | OnlyInput -> var_ctx.scope_var_typ, pos_sigma
| Reentrant -> | Reentrant -> (
match var_ctx.scope_var_typ with
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
| _ ->
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)), ( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma ) pos_sigma ))
| NoInput -> failwith "should not happen" | NoInput -> failwith "should not happen"
in in
let input_destructurings next = let input_destructurings next =
List.fold_right List.fold_right
(fun (var_ctx, v) next -> (fun (var_ctx, v) next ->
let field = let field =
Marked.unmark
(ScopeVar.Map.find var_ctx.scope_var_name (ScopeVar.Map.find var_ctx.scope_var_name
scope_sig.scope_sig_in_fields) scope_sig.scope_sig_in_fields)
.scope_input_name
in in
Bindlib.box_apply2 Bindlib.box_apply2
(fun next r -> (fun next r ->
@ -775,7 +837,9 @@ let translate_scope_decl
List.fold_left List.fold_left
(fun acc (var_ctx, _) -> (fun acc (var_ctx, _) ->
let var = var_ctx.scope_var_name in let var = var_ctx.scope_var_name in
let field, _ = ScopeVar.Map.find var scope_sig.scope_sig_in_fields in let field =
(ScopeVar.Map.find var scope_sig.scope_sig_in_fields).scope_input_name
in
StructField.Map.add field (input_var_typ var_ctx) acc) StructField.Map.add field (input_var_typ var_ctx) acc)
StructField.Map.empty scope_input_variables StructField.Map.empty scope_input_variables
in in
@ -820,15 +884,19 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
in in
let scope_sig_in_fields = let scope_sig_in_fields =
ScopeVar.Map.filter_map ScopeVar.Map.filter_map
(fun dvar (_, vis) -> (fun dvar (typ, vis) ->
match Marked.unmark vis.Desugared.Ast.io_input with match Marked.unmark vis.Desugared.Ast.io_input with
| NoInput -> None | NoInput -> None
| OnlyInput | Reentrant -> | OnlyInput | Reentrant ->
let info = ScopeVar.get_info dvar in let info = ScopeVar.get_info dvar in
let s = Marked.unmark info ^ "_in" in let s = Marked.unmark info ^ "_in" in
Some Some
( StructField.fresh (s, Marked.get_mark info), {
vis.Desugared.Ast.io_input )) scope_input_name =
StructField.fresh (s, Marked.get_mark info);
scope_input_io = vis.Desugared.Ast.io_input;
scope_input_typ = Marked.unmark typ;
})
scope.scope_sig scope.scope_sig
in in
{ {

View File

@ -504,8 +504,11 @@ let interpret_program :
StructField.Map.map StructField.Map.map
(fun ty -> (fun ty ->
match Marked.unmark ty with match Marked.unmark ty with
| TArrow ((TLit TUnit, _), ty_in) -> | TArrow (ty_in, ty_out) ->
Expr.empty_thunked_term (Expr.with_ty mark_e ty_in) Expr.make_abs
[| Var.make "_" |]
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
[ty_in] (Expr.mark_pos mark_e)
| _ -> | _ ->
Errors.raise_spanned_error (Marked.get_mark ty) Errors.raise_spanned_error (Marked.get_mark ty)
"This scope needs input arguments to be executed. But the Catala \ "This scope needs input arguments to be executed. But the Catala \

View File

@ -191,6 +191,7 @@ let def_map_to_tree
when to place the toplevel binding in the case of functions. *) when to place the toplevel binding in the case of functions. *)
let rec rule_tree_to_expr let rec rule_tree_to_expr
~(toplevel : bool) ~(toplevel : bool)
~(is_reentrant_var : bool)
(ctx : ctx) (ctx : ctx)
(def_pos : Pos.t) (def_pos : Pos.t)
(is_func : Desugared.Ast.expr Var.t option) (is_func : Desugared.Ast.expr Var.t option)
@ -271,7 +272,9 @@ let rec rule_tree_to_expr
emark emark
in in
let exceptions = let exceptions =
List.map (rule_tree_to_expr ~toplevel:false ctx def_pos is_func) exceptions List.map
(rule_tree_to_expr ~toplevel:false ~is_reentrant_var ctx def_pos is_func)
exceptions
in in
let default = let default =
Expr.make_default exceptions Expr.make_default exceptions
@ -283,8 +286,13 @@ let rec rule_tree_to_expr
| Some new_param, Some (_, typ) -> | Some new_param, Some (_, typ) ->
if toplevel then if toplevel then
(* When we're creating a function from multiple defaults, we must check (* When we're creating a function from multiple defaults, we must check
that the result returned by the function is not empty *) that the result returned by the function is not empty, unless we're
let default = Expr.eerroronempty default emark in dealing with a context variable which is reentrant (either in the
caller or callee). In this case the ErrorOnEmpty will be added later in
the scopelang->dcalc translation. *)
let default =
if is_reentrant_var then default else Expr.eerroronempty default emark
in
Expr.make_abs Expr.make_abs
[| Var.Map.find new_param ctx.var_mapping |] [| Var.Map.find new_param ctx.var_mapping |]
default [typ] def_pos default [typ] def_pos
@ -348,6 +356,11 @@ let translate_def
| OnlyInput -> true | OnlyInput -> true
| _ -> false | _ -> false
in in
let is_reentrant =
match Marked.unmark io.Desugared.Ast.io_input with
| Reentrant -> true
| _ -> false
in
let top_value = let top_value =
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
(* We add the bottom [false] value for conditions, only for the scope (* We add the bottom [false] value for conditions, only for the scope
@ -385,10 +398,14 @@ let translate_def
will not be provided by the calee scope, it has to be placed in the will not be provided by the calee scope, it has to be placed in the
caller. *) caller. *)
then then
Expr.elit LEmptyError let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in
(Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info }) let empty_error = Expr.elit LEmptyError m in
match is_def_func_param_typ with
| Some ty ->
Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m)
| _ -> empty_error
else else
rule_tree_to_expr ~toplevel:true ctx rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx
(Desugared.Ast.ScopeDef.get_position def_info) (Desugared.Ast.ScopeDef.get_position def_info)
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ) (Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
(match top_list, top_value with (match top_list, top_value with

View File

@ -82,6 +82,18 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : typed expr) :
(* context sub-scope variables *) (* context sub-scope variables *)
let _, body = Bindlib.unmbind binder in let _, body = Bindlib.unmbind binder in
body body
| EAbs { binder; _ } -> (
(* context scope variables *)
let _, body = Bindlib.unmbind binder in
match Marked.unmark body with
| EErrorOnEmpty e -> e
| _ ->
Errors.raise_spanned_error (Expr.pos e)
"Internal error: this expression does not have the structure expected \
by the VC generator:\n\
%a"
(Expr.format ~debug:true ctx.decl)
e)
| EErrorOnEmpty d -> | EErrorOnEmpty d ->
d (* input subscope variables and non-input scope variable *) d (* input subscope variables and non-input scope variable *)
| _ -> | _ ->

View File

@ -0,0 +1,46 @@
## Test
```catala
declaration scope A:
context f content integer depends on integer
declaration scope B:
input b content boolean
a scope A
scope A:
definition f of x equals x + 1
scope B:
definition a.f of x under condition b and x > 0 consequence equals x - 1
```
```catala-test-inline
$ catala Scopelang -s B
let scope B (b: bool|input) =
let a.f : integer → integer =
λ (param: integer) → ⟨b && param > 0 ⊢ param - 1⟩;
call A[a]
```
```catala-test-inline
$ catala Dcalc -s A
let A =
λ (A_in: A_in {"f_in": integer → integer}) →
let f : integer → integer = A_in."f_in" in
let f1 : integer → integer =
λ (param: integer) → error_empty
⟨f param | true ⊢ ⟨true ⊢ param + 1⟩⟩ in
A { }
```
```catala-test-inline
$ catala Dcalc -s B
let B =
λ (B_in: B_in {"b_in": bool}) →
let b : bool = B_in."b_in" in
let a.f : integer → integer =
λ (param: integer) → ⟨b && param > 0 ⊢ param - 1⟩ in
let result : A {} = A (A_in { "f_in"= a.f }) in
B { }
```