mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Fix 362 (was harder than expected and unit tests helped catch subsequent encoding bugs!)
This commit is contained in:
parent
eee9946847
commit
e448a1a1b4
@ -23,6 +23,12 @@ type scope_var_ctx = {
|
||||
scope_var_io : Desugared.Ast.io;
|
||||
}
|
||||
|
||||
type scope_input_var_ctx = {
|
||||
scope_input_name : StructFieldName.t;
|
||||
scope_input_io : Desugared.Ast.io_input Marked.pos;
|
||||
scope_input_typ : naked_typ;
|
||||
}
|
||||
|
||||
type 'm scope_sig_ctx = {
|
||||
scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *)
|
||||
scope_sig_scope_var : 'm Ast.expr Var.t; (** Var representing the scope *)
|
||||
@ -30,8 +36,7 @@ type 'm scope_sig_ctx = {
|
||||
(** Var representing the scope input inside the scope func *)
|
||||
scope_sig_input_struct : StructName.t; (** Scope input *)
|
||||
scope_sig_output_struct : StructName.t; (** Scope output *)
|
||||
scope_sig_in_fields :
|
||||
(StructFieldName.t * Desugared.Ast.io_input Marked.pos) ScopeVarMap.t;
|
||||
scope_sig_in_fields : scope_input_var_ctx ScopeVarMap.t;
|
||||
(** Mapping between the input scope variables and the input struct fields. *)
|
||||
scope_sig_out_fields : StructFieldName.t ScopeVarMap.t;
|
||||
(** Mapping between the output scope variables and the output struct
|
||||
@ -80,23 +85,66 @@ let pos_mark_mk (type a m) (e : (a, m mark) gexpr) :
|
||||
let pos_mark_as e = pos_mark (Marked.get_mark e) in
|
||||
pos_mark, pos_mark_as
|
||||
|
||||
let merge_defaults caller callee =
|
||||
let caller =
|
||||
let m = Marked.get_mark caller in
|
||||
let pos = Expr.mark_pos m in
|
||||
Expr.make_app caller
|
||||
[Expr.elit LUnit (Expr.with_ty m (Marked.mark pos (TLit TUnit)))]
|
||||
pos
|
||||
in
|
||||
let body =
|
||||
let m = Marked.get_mark callee in
|
||||
let ltrue =
|
||||
Expr.elit (LBool true)
|
||||
(Expr.with_ty m (Marked.mark (Expr.mark_pos m) (TLit TBool)))
|
||||
let merge_defaults
|
||||
~(is_func : bool)
|
||||
(caller : (dcalc, 'm mark) boxed_gexpr)
|
||||
(callee : (dcalc, 'm mark) boxed_gexpr) : (dcalc, 'm mark) boxed_gexpr =
|
||||
(* the merging of the two defaults, from the reentrant caller and the callee,
|
||||
is straightfoward in the general case and a little subtler when the
|
||||
variable being defined is a function. *)
|
||||
if is_func then
|
||||
let m_callee = Marked.get_mark callee in
|
||||
Bindlib.unbox
|
||||
(Bindlib.box_apply
|
||||
(fun naked_unboxed_callee ->
|
||||
match naked_unboxed_callee with
|
||||
| EAbs { binder; tys } ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let m_body = Marked.get_mark body in
|
||||
let caller =
|
||||
let m = Marked.get_mark caller in
|
||||
let pos = Expr.mark_pos m in
|
||||
Expr.make_app caller
|
||||
(List.map2
|
||||
(fun (var : (dcalc, 'm mark) naked_gexpr Bindlib.var) ty ->
|
||||
Expr.evar var
|
||||
(* we have to correctly propagate types when doing this
|
||||
rewriting *)
|
||||
(Expr.with_ty m_body ~pos:(Expr.mark_pos m_body) ty))
|
||||
(Array.to_list vars) tys)
|
||||
pos
|
||||
in
|
||||
|
||||
let ltrue =
|
||||
Expr.elit (LBool true)
|
||||
(Expr.with_ty m_callee
|
||||
(Marked.mark (Expr.mark_pos m_callee) (TLit TBool)))
|
||||
in
|
||||
let d = Expr.edefault [caller] ltrue (Expr.rebox body) m_body in
|
||||
Expr.make_abs vars
|
||||
(Expr.eerroronempty d m_body)
|
||||
tys (Expr.mark_pos m_callee)
|
||||
| _ -> assert false
|
||||
(* should not happen because there should always be a lambda at the
|
||||
beginning of a default with a function type *))
|
||||
(Marked.unmark callee))
|
||||
else
|
||||
let caller =
|
||||
let m = Marked.get_mark caller in
|
||||
let pos = Expr.mark_pos m in
|
||||
Expr.make_app caller
|
||||
[Expr.elit LUnit (Expr.with_ty m (Marked.mark pos (TLit TUnit)))]
|
||||
pos
|
||||
in
|
||||
Expr.edefault [caller] ltrue callee m
|
||||
in
|
||||
body
|
||||
let body =
|
||||
let m = Marked.get_mark callee in
|
||||
let ltrue =
|
||||
Expr.elit (LBool true)
|
||||
(Expr.with_ty m (Marked.mark (Expr.mark_pos m) (TLit TBool)))
|
||||
in
|
||||
Expr.eerroronempty (Expr.edefault [caller] ltrue callee m) m
|
||||
in
|
||||
body
|
||||
|
||||
let tag_with_log_entry
|
||||
(e : 'm Ast.expr boxed)
|
||||
@ -149,7 +197,7 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
|
||||
in
|
||||
excepts
|
||||
|
||||
let thunk_scope_arg io_in e =
|
||||
let thunk_scope_arg ~is_func io_in e =
|
||||
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
|
||||
that we can put them in default terms at the initialisation of the function
|
||||
body, allowing an empty error to recover the default value. *)
|
||||
@ -159,7 +207,9 @@ let thunk_scope_arg io_in e =
|
||||
| Desugared.Ast.NoInput -> invalid_arg "thunk_scope_arg"
|
||||
| Desugared.Ast.OnlyInput -> Expr.eerroronempty e (Marked.get_mark e)
|
||||
| Desugared.Ast.Reentrant ->
|
||||
Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
|
||||
(* we don't need to thunk expressions that are already functions *)
|
||||
if is_func then e
|
||||
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
|
||||
|
||||
let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
'm Ast.expr boxed =
|
||||
@ -213,23 +263,31 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|
||||
let sc_sig = ScopeMap.find scope ctx.scopes_parameters in
|
||||
let in_var_map =
|
||||
ScopeVarMap.merge
|
||||
(fun var_name str_field expr ->
|
||||
(fun var_name (str_field : scope_input_var_ctx option) expr ->
|
||||
let expr =
|
||||
match str_field, expr with
|
||||
| Some (_, (Desugared.Ast.Reentrant, _)), None ->
|
||||
| Some { scope_input_io = Desugared.Ast.Reentrant, _; _ }, None ->
|
||||
Some (Expr.unbox (Expr.elit LEmptyError (mark_tany m pos)))
|
||||
| _ -> expr
|
||||
in
|
||||
match str_field, expr with
|
||||
| None, None -> None
|
||||
| Some (fld, io_in), Some e ->
|
||||
Some (fld, thunk_scope_arg io_in (translate_expr ctx e))
|
||||
| Some (fld, _), None ->
|
||||
| Some var_ctx, Some e ->
|
||||
Some
|
||||
( var_ctx.scope_input_name,
|
||||
thunk_scope_arg
|
||||
~is_func:
|
||||
(match var_ctx.scope_input_typ with
|
||||
| TArrow _ -> true
|
||||
| _ -> false)
|
||||
var_ctx.scope_input_io (translate_expr ctx e) )
|
||||
| Some var_ctx, None ->
|
||||
Errors.raise_multispanned_error
|
||||
[
|
||||
None, pos;
|
||||
( Some "Declaration of the missing input variable",
|
||||
Marked.get_mark (StructFieldName.get_info fld) );
|
||||
Marked.get_mark
|
||||
(StructFieldName.get_info var_ctx.scope_input_name) );
|
||||
]
|
||||
"Definition of input variable '%a' missing in this scope call"
|
||||
ScopeVar.format_t var_name
|
||||
@ -386,14 +444,15 @@ let translate_rule
|
||||
let new_e = translate_expr ctx e in
|
||||
let a_expr = Expr.make_var a_var (pos_mark var_def_pos) in
|
||||
let merged_expr =
|
||||
Expr.eerroronempty
|
||||
(match Marked.unmark a_io.io_input with
|
||||
| OnlyInput -> failwith "should not happen"
|
||||
(* scopelang should not contain any definitions of input only
|
||||
variables *)
|
||||
| Reentrant -> merge_defaults a_expr new_e
|
||||
| NoInput -> new_e)
|
||||
(pos_mark_as a_name)
|
||||
match Marked.unmark a_io.io_input with
|
||||
| OnlyInput -> failwith "should not happen"
|
||||
(* scopelang should not contain any definitions of input only variables *)
|
||||
| Reentrant ->
|
||||
merge_defaults
|
||||
~is_func:
|
||||
(match Marked.unmark tau with TArrow _ -> true | _ -> false)
|
||||
a_expr new_e
|
||||
| NoInput -> Expr.eerroronempty new_e (pos_mark_as a_name)
|
||||
in
|
||||
let merged_expr =
|
||||
tag_with_log_entry merged_expr
|
||||
@ -437,8 +496,11 @@ let translate_rule
|
||||
(VarDef (Marked.unmark tau))
|
||||
[sigma_name, pos_sigma; a_name]
|
||||
in
|
||||
let is_func =
|
||||
match Marked.unmark tau with TArrow _ -> true | _ -> false
|
||||
in
|
||||
let thunked_or_nonempty_new_e =
|
||||
thunk_scope_arg a_io.Desugared.Ast.io_input new_e
|
||||
thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e
|
||||
in
|
||||
( (fun next ->
|
||||
Bindlib.box_apply2
|
||||
@ -452,7 +514,8 @@ let translate_rule
|
||||
| NoInput -> failwith "should not happen"
|
||||
| OnlyInput -> tau
|
||||
| Reentrant ->
|
||||
TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
|
||||
if is_func then tau
|
||||
else TArrow ((TLit TUnit, var_def_pos), tau), var_def_pos);
|
||||
scope_let_expr = thunked_or_nonempty_new_e;
|
||||
scope_let_kind = SubScopeVarDefinition;
|
||||
})
|
||||
@ -520,9 +583,9 @@ let translate_rule
|
||||
Expr.make_var a_var (mark_tany m pos_call)
|
||||
in
|
||||
let field =
|
||||
Marked.unmark
|
||||
(ScopeVarMap.find subvar.scope_var_name
|
||||
subscope_sig.scope_sig_in_fields)
|
||||
(ScopeVarMap.find subvar.scope_var_name
|
||||
subscope_sig.scope_sig_in_fields)
|
||||
.scope_input_name
|
||||
in
|
||||
StructFieldMap.add field e acc)
|
||||
StructFieldMap.empty all_subscope_input_vars
|
||||
@ -738,18 +801,20 @@ let translate_scope_decl
|
||||
let input_var_typ (var_ctx : scope_var_ctx) =
|
||||
match Marked.unmark var_ctx.scope_var_io.io_input with
|
||||
| OnlyInput -> var_ctx.scope_var_typ, pos_sigma
|
||||
| Reentrant ->
|
||||
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
|
||||
pos_sigma )
|
||||
| Reentrant -> (
|
||||
match var_ctx.scope_var_typ with
|
||||
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
|
||||
| _ ->
|
||||
( TArrow ((TLit TUnit, pos_sigma), (var_ctx.scope_var_typ, pos_sigma)),
|
||||
pos_sigma ))
|
||||
| NoInput -> failwith "should not happen"
|
||||
in
|
||||
let input_destructurings next =
|
||||
List.fold_right
|
||||
(fun (var_ctx, v) next ->
|
||||
let field =
|
||||
Marked.unmark
|
||||
(ScopeVarMap.find var_ctx.scope_var_name
|
||||
scope_sig.scope_sig_in_fields)
|
||||
(ScopeVarMap.find var_ctx.scope_var_name scope_sig.scope_sig_in_fields)
|
||||
.scope_input_name
|
||||
in
|
||||
Bindlib.box_apply2
|
||||
(fun next r ->
|
||||
@ -774,7 +839,9 @@ let translate_scope_decl
|
||||
List.fold_left
|
||||
(fun acc (var_ctx, _) ->
|
||||
let var = var_ctx.scope_var_name in
|
||||
let field, _ = ScopeVarMap.find var scope_sig.scope_sig_in_fields in
|
||||
let field =
|
||||
(ScopeVarMap.find var scope_sig.scope_sig_in_fields).scope_input_name
|
||||
in
|
||||
StructFieldMap.add field (input_var_typ var_ctx) acc)
|
||||
StructFieldMap.empty scope_input_variables
|
||||
in
|
||||
@ -817,15 +884,19 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
|
||||
in
|
||||
let scope_sig_in_fields =
|
||||
ScopeVarMap.filter_map
|
||||
(fun dvar (_, vis) ->
|
||||
(fun dvar (typ, vis) ->
|
||||
match Marked.unmark vis.Desugared.Ast.io_input with
|
||||
| NoInput -> None
|
||||
| OnlyInput | Reentrant ->
|
||||
let info = ScopeVar.get_info dvar in
|
||||
let s = Marked.unmark info ^ "_in" in
|
||||
Some
|
||||
( StructFieldName.fresh (s, Marked.get_mark info),
|
||||
vis.Desugared.Ast.io_input ))
|
||||
{
|
||||
scope_input_name =
|
||||
StructFieldName.fresh (s, Marked.get_mark info);
|
||||
scope_input_io = vis.Desugared.Ast.io_input;
|
||||
scope_input_typ = Marked.unmark typ;
|
||||
})
|
||||
scope.scope_sig
|
||||
in
|
||||
{
|
||||
|
@ -504,8 +504,11 @@ let interpret_program :
|
||||
StructFieldMap.map
|
||||
(fun ty ->
|
||||
match Marked.unmark ty with
|
||||
| TArrow ((TLit TUnit, _), ty_in) ->
|
||||
Expr.empty_thunked_term (Expr.with_ty mark_e ty_in)
|
||||
| TArrow (ty_in, ty_out) ->
|
||||
Expr.make_abs
|
||||
[| Var.make "_" |]
|
||||
(Bindlib.box (ELit LEmptyError), Expr.with_ty mark_e ty_out)
|
||||
[ty_in] (Expr.mark_pos mark_e)
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Marked.get_mark ty)
|
||||
"This scope needs input arguments to be executed. But the Catala \
|
||||
|
@ -171,7 +171,7 @@ let def_map_to_tree
|
||||
when to place the toplevel binding in the case of functions. *)
|
||||
let rec rule_tree_to_expr
|
||||
~(toplevel : bool)
|
||||
~(is_subscope_var : bool)
|
||||
~(is_reentrant_var : bool)
|
||||
(ctx : ctx)
|
||||
(def_pos : Pos.t)
|
||||
(is_func : Desugared.Ast.expr Var.t option)
|
||||
@ -253,7 +253,7 @@ let rec rule_tree_to_expr
|
||||
in
|
||||
let exceptions =
|
||||
List.map
|
||||
(rule_tree_to_expr ~toplevel:false ~is_subscope_var ctx def_pos is_func)
|
||||
(rule_tree_to_expr ~toplevel:false ~is_reentrant_var ctx def_pos is_func)
|
||||
exceptions
|
||||
in
|
||||
let default =
|
||||
@ -267,9 +267,11 @@ let rec rule_tree_to_expr
|
||||
if toplevel then
|
||||
(* When we're creating a function from multiple defaults, we must check
|
||||
that the result returned by the function is not empty, unless we're
|
||||
feeding a context variabled in a called subscope. *)
|
||||
dealing with a context variable which is reentrant (either in the
|
||||
caller or callee). In this case the ErrorOnEmpty will be added later in
|
||||
the scopelang->dcalc translation. *)
|
||||
let default =
|
||||
if is_subscope_var then default else Expr.eerroronempty default emark
|
||||
if is_reentrant_var then default else Expr.eerroronempty default emark
|
||||
in
|
||||
Expr.make_abs
|
||||
[| Var.Map.find new_param ctx.var_mapping |]
|
||||
@ -334,6 +336,11 @@ let translate_def
|
||||
| OnlyInput -> true
|
||||
| _ -> false
|
||||
in
|
||||
let is_reentrant =
|
||||
match Marked.unmark io.Desugared.Ast.io_input with
|
||||
| Reentrant -> true
|
||||
| _ -> false
|
||||
in
|
||||
let top_value =
|
||||
if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then
|
||||
(* We add the bottom [false] value for conditions, only for the scope
|
||||
@ -371,10 +378,14 @@ let translate_def
|
||||
will not be provided by the calee scope, it has to be placed in the
|
||||
caller. *)
|
||||
then
|
||||
Expr.elit LEmptyError
|
||||
(Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info })
|
||||
let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in
|
||||
let empty_error = Expr.elit LEmptyError m in
|
||||
match is_def_func_param_typ with
|
||||
| Some ty ->
|
||||
Expr.make_abs [| Var.make "_" |] empty_error [ty] (Expr.mark_pos m)
|
||||
| _ -> empty_error
|
||||
else
|
||||
rule_tree_to_expr ~toplevel:true ~is_subscope_var ctx
|
||||
rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx
|
||||
(Desugared.Ast.ScopeDef.get_position def_info)
|
||||
(Option.map (fun _ -> Var.make "param") is_def_func_param_typ)
|
||||
(match top_list, top_value with
|
||||
|
@ -82,6 +82,18 @@ let match_and_ignore_outer_reentrant_default (ctx : ctx) (e : typed expr) :
|
||||
(* context sub-scope variables *)
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
body
|
||||
| EAbs { binder; _ } -> (
|
||||
(* context scope variables *)
|
||||
let _, body = Bindlib.unmbind binder in
|
||||
match Marked.unmark body with
|
||||
| EErrorOnEmpty e -> e
|
||||
| _ ->
|
||||
Errors.raise_spanned_error (Expr.pos e)
|
||||
"Internal error: this expression does not have the structure expected \
|
||||
by the VC generator:\n\
|
||||
%a"
|
||||
(Expr.format ~debug:true ctx.decl)
|
||||
e)
|
||||
| EErrorOnEmpty d ->
|
||||
d (* input subscope variables and non-input scope variable *)
|
||||
| _ ->
|
||||
|
@ -28,9 +28,9 @@ $ catala Dcalc -s A
|
||||
let A =
|
||||
λ (A_in: A_in {"f_in": integer → integer}) →
|
||||
let f : integer → integer = A_in."f_in" in
|
||||
let f1 : integer → integer = error_empty
|
||||
λ (param: integer) → ⟨f param | true ⊢
|
||||
error_empty ⟨true ⊢ param + 1⟩⟩ in
|
||||
let f1 : integer → integer =
|
||||
λ (param: integer) → error_empty
|
||||
⟨f param | true ⊢ ⟨true ⊢ param + 1⟩⟩ in
|
||||
A { }
|
||||
```
|
||||
|
||||
@ -40,8 +40,7 @@ let B =
|
||||
λ (B_in: B_in {"b_in": bool}) →
|
||||
let b : bool = B_in."b_in" in
|
||||
let a.f : integer → integer =
|
||||
λ (param: integer) → error_empty
|
||||
⟨b && param > 0 ⊢ param - 1⟩ in
|
||||
λ (param: integer) → ⟨b && param > 0 ⊢ param - 1⟩ in
|
||||
let result : A {} = A (A_in { "f_in"= a.f }) in
|
||||
B { }
|
||||
```
|
||||
|
Loading…
Reference in New Issue
Block a user