From 22062434e4901e3213eeefaee88a1dfb2f432d40 Mon Sep 17 00:00:00 2001 From: Nicolas Chataing Date: Fri, 7 Aug 2020 15:56:32 +0200 Subject: [PATCH] Add EDefault to terms, and replace default_term by term in the scope language --- src/catala/debug.ml | 7 +- src/catala/ir/lambda.ml | 44 +++---- src/catala/ir/scope.ml | 2 +- src/catala/translation/firstpass.ml | 158 +++++++++----------------- src/catala/translation/interpreter.ml | 120 +++++++++---------- 5 files changed, 131 insertions(+), 200 deletions(-) diff --git a/src/catala/debug.ml b/src/catala/debug.ml index 52b079e9..48b401db 100644 --- a/src/catala/debug.ml +++ b/src/catala/debug.ml @@ -85,9 +85,9 @@ let rec print_term (((t, _), _) : Lambda.term) : string = | EBool b -> if b then "true" else "false" | EDec (i, f) -> Printf.sprintf "%d.%d" i f | EOp op -> print_op op + | EDefault t -> print_default_term t -(** Print default term *) -let print_default_term (term : Lambda.default_term) : string = +and print_default_term (term : Lambda.default_term) : string = term.defaults |> Lambda.IntMap.bindings |> List.map (fun (_, (cond, body)) -> Printf.sprintf "\t%s => %s" (print_term cond) (print_term body)) @@ -97,8 +97,7 @@ let print_default_term (term : Lambda.default_term) : string = let print_scope (scope : Scope.scope) : string = let print_defs (defs : Scope.definition UidMap.t) : string = defs |> UidMap.bindings - |> List.map (fun (uid, term) -> - Printf.sprintf "%s:\n%s" (Uid.get_ident uid) (print_default_term term)) + |> List.map (fun (uid, term) -> Printf.sprintf "%s:\n%s" (Uid.get_ident uid) (print_term term)) |> String.concat "" in "___Variables Definition___\n" ^ print_defs scope.scope_defs ^ "___Subscope (Re)definition___\n" diff --git a/src/catala/ir/lambda.ml b/src/catala/ir/lambda.ml index 677e7eff..3d40ee8c 100644 --- a/src/catala/ir/lambda.ml +++ b/src/catala/ir/lambda.ml @@ -14,6 +14,7 @@ module UidMap = Uid.UidMap module UidSet = Uid.UidSet +module IntMap = Map.Make (Int) (* TDummy means the term is not typed *) type typ = TBool | TInt | TArrow of typ * typ | TDummy @@ -41,6 +42,15 @@ and untyped_term = | EBool of bool | EDec of int * int | EOp of op + | EDefault of default_term + +(* (x,y) in ordering means that default x has precedence over default y : if both are true then x + would be choser over y *) +and default_term = { + defaults : (term * term) IntMap.t; + ordering : (int * int) list; + nb_defaults : int; +} let untype (((term, _), _) : term) : untyped_term = term @@ -48,25 +58,9 @@ let get_pos (((_, pos), _) : term) : Pos.t = pos let get_typ ((_, typ) : term) : typ = typ -(* Default terms *) - -module IntMap = Map.Make (Int) - -type justification = term - -type consequence = term - -(* (x,y) in ordering means that default x has precedence over default y : if both are true then x - would be choser over y *) -type default_term = { - defaults : (justification * consequence) IntMap.t; - ordering : (int * int) list; - nb_defaults : int; -} - let empty_default_term : default_term = { defaults = IntMap.empty; ordering = []; nb_defaults = 0 } -let add_default (just : justification) (cons : consequence) (term : default_term) = +let add_default (just : term) (cons : term) (term : default_term) = { term with defaults = IntMap.add term.nb_defaults (just, cons) term.defaults; @@ -99,24 +93,22 @@ let merge_default_terms (lo_term : default_term) (hi_term : default_term) : defa { defaults; ordering = prec @ prec'; nb_defaults = n + n' } (** Returns the free variables of a term *) -let rec term_free_vars (term : term) : UidSet.t = +let rec term_fv (term : term) : UidSet.t = match untype term with | EVar uid -> UidSet.singleton uid | EFun (bindings, body) -> - let body_fv = term_free_vars body in + let body_fv = term_fv body in let bindings = bindings |> List.map (fun (x, _) -> x) |> UidSet.of_list in UidSet.diff body_fv bindings - | EApp (f, args) -> - List.fold_left (fun fv arg -> UidSet.union fv (term_free_vars arg)) (term_free_vars f) args + | EApp (f, args) -> List.fold_left (fun fv arg -> UidSet.union fv (term_fv arg)) (term_fv f) args | EIfThenElse (t_if, t_then, t_else) -> - UidSet.union (term_free_vars t_if) - (UidSet.union (term_free_vars t_then) (term_free_vars t_else)) + UidSet.union (term_fv t_if) (UidSet.union (term_fv t_then) (term_fv t_else)) + | EDefault default -> default_term_fv default | _ -> UidSet.empty -(** Returns the free variable of a default term *) -let default_term_fv (term : default_term) : UidSet.t = +and default_term_fv (term : default_term) : UidSet.t = IntMap.fold (fun _ (cond, body) -> - let fv = UidSet.union (term_free_vars cond) (term_free_vars body) in + let fv = UidSet.union (term_fv cond) (term_fv body) in UidSet.union fv) term.defaults UidSet.empty diff --git a/src/catala/ir/scope.ml b/src/catala/ir/scope.ml index 7e50a158..5bb57a79 100644 --- a/src/catala/ir/scope.ml +++ b/src/catala/ir/scope.ml @@ -17,7 +17,7 @@ module UidMap = Uid.UidMap (* Scopes *) type binder = string Pos.marked -type definition = Lambda.default_term +type definition = Lambda.term type assertion = Lambda.term diff --git a/src/catala/translation/firstpass.ml b/src/catala/translation/firstpass.ml index ffa96e89..e6ac9f28 100644 --- a/src/catala/translation/firstpass.ml +++ b/src/catala/translation/firstpass.ml @@ -68,12 +68,11 @@ let rec expr_to_lambda ?(subdef : Uid.t option) (scope : Uid.t) (ctxt : Context. | _ -> assert false (** Checks that a term is well typed and annotate it *) -let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.term * Lambda.typ = +let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.term = match t with | EVar uid -> let typ = Context.get_uid_typ ctxt uid in - let term = ((EVar uid, pos), typ) in - (term, typ) + ((EVar uid, pos), typ) | EFun (bindings, body) -> (* Note that given the context formation process, the binder will already be present in the context (since we are working with uids), however it's added there for the sake of clarity *) @@ -85,16 +84,19 @@ let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.t ctxt.data bindings in - let body, ret_typ = typing { ctxt with data = ctxt_data } body in + let body = typing { ctxt with data = ctxt_data } body in + let ret_typ = Lambda.get_typ body in let rec build_typ = function | [] -> ret_typ | (_, arg_t) :: args -> TArrow (arg_t, build_typ args) in let fun_typ = build_typ bindings in - (((EFun (bindings, body), pos), fun_typ), fun_typ) + ((EFun (bindings, body), pos), fun_typ) | EApp (f, args) -> - let f, f_typ = typing ctxt f in - let args, args_typ = args |> List.map (typing ctxt) |> List.split in + let f = typing ctxt f in + let f_typ = Lambda.get_typ f in + let args = List.map (typing ctxt) args in + let args_typ = List.map Lambda.get_typ args in let rec check_arrow_typ f_typ args_typ = match (f_typ, args_typ) with | typ, [] -> typ @@ -103,16 +105,19 @@ let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.t | _ -> assert false in let ret_typ = check_arrow_typ f_typ args_typ in - (((EApp (f, args), pos), ret_typ), ret_typ) + ((EApp (f, args), pos), ret_typ) | EIfThenElse (t_if, t_then, t_else) -> - let t_if, typ_if = typing ctxt t_if in - let t_then, typ_then = typing ctxt t_then in - let t_else, typ_else = typing ctxt t_else in + let t_if = typing ctxt t_if in + let typ_if = Lambda.get_typ t_if in + let t_then = typing ctxt t_then in + let typ_then = Lambda.get_typ t_then in + let t_else = typing ctxt t_else in + let typ_else = Lambda.get_typ t_else in if typ_if <> TBool then assert false else if typ_then <> typ_else then assert false - else (((EIfThenElse (t_if, t_then, t_else), pos), typ_then), typ_then) - | EInt _ | EDec _ -> (((t, pos), TInt), TInt) - | EBool _ -> (((t, pos), TBool), TBool) + else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then) + | EInt _ | EDec _ -> ((t, pos), TInt) + | EBool _ -> ((t, pos), TBool) | EOp op -> let typ = match op with @@ -124,7 +129,23 @@ let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.t | Unop Minus -> TArrow (TInt, TInt) | Unop Not -> TArrow (TBool, TBool) in - (((t, pos), typ), typ) + ((t, pos), typ) + | EDefault t -> + let defaults = + IntMap.map + (fun (just, cons) -> + let just = typing ctxt just in + if Lambda.get_typ just <> TBool then + let cons = typing ctxt cons in + (just, cons) + else assert false) + t.defaults + in + let typ_cons = IntMap.choose defaults |> snd |> snd |> Lambda.get_typ in + IntMap.iter + (fun _ (_, cons) -> if Lambda.get_typ cons <> typ_cons then assert false else ()) + defaults; + ((EDefault { t with defaults }, pos), typ_cons) (* Translation from the parsed ast to the scope language *) @@ -137,93 +158,24 @@ let merge_conditions (precond : Lambda.term option) (cond : Lambda.term option) | Some cond, None | None, Some cond -> cond | None, None -> ((EBool true, default_pos), TBool) -(* Process a definition *) -let process_def (precond : Lambda.term option) (scope : Uid.t) (ctxt : Context.context) - (prgm : Scope.program) (def : Ast.definition) : Scope.program = - (* We first check either it is a variable or a subvariable *) - let scope_prgm = UidMap.find scope prgm in - let pos = Pos.get_position def.definition_name in - let scope_prgm = - match Pos.unmark def.definition_name with - | [ x ] -> - let x_uid = Context.get_var_uid scope ctxt x in - let x_def = - match UidMap.find_opt x_uid scope_prgm.scope_defs with - | None -> Lambda.empty_default_term - | Some def -> def - in - (* ctxt redefines just the ident lookup for the argument binding (in case x is a function *) - let ctxt, arg_uid = Context.add_binding ctxt scope x_uid def.definition_parameter in - (* Process the condition *) - let cond = - match def.definition_condition with - | Some cond -> - let cond, typ = typing ctxt (expr_to_lambda scope ctxt cond) in - if typ = TBool then Some cond else assert false - | None -> None - in - let condition = - merge_conditions precond cond (Pos.get_position def.definition_name) |> typing ctxt |> fst - in - let body = expr_to_lambda scope ctxt def.definition_expr in - (* In case it is a function, wrap it in a EFun*) - let body = - ( match arg_uid with - | None -> body - | Some arg_uid -> - let binding = (arg_uid, Context.get_uid_typ ctxt arg_uid) in - ((EFun ([ binding ], body), Pos.get_position def.definition_expr), TDummy) ) - |> typing ctxt |> fst - in - let x_def = Lambda.add_default condition body x_def in - { scope_prgm with scope_defs = UidMap.add x_uid x_def scope_prgm.scope_defs } - | [ y; x ] -> - let subscope_uid, scope_ref = Context.get_subscope_uid scope ctxt y in - let x_uid = Context.get_var_uid scope_ref ctxt x in - let y_subdef = - match UidMap.find_opt subscope_uid scope_prgm.scope_sub_defs with - | Some defs -> defs - | None -> UidMap.empty - in - let x_redef = - match UidMap.find_opt x_uid y_subdef with - | None -> Lambda.empty_default_term - | Some redef -> redef - in - (* ctxt redefines just the ident lookup for the argument binding (in case x is a function *) - let ctxt, arg_uid = Context.add_binding ctxt scope x_uid def.definition_parameter in - (* Process cond with the subdef argument*) - let cond = - match def.definition_condition with - | Some cond -> - let cond, typ = expr_to_lambda ~subdef:scope_ref scope ctxt cond |> typing ctxt in - if typ = TBool then Some cond else assert false - | None -> None - in - let condition = - merge_conditions precond cond (Pos.get_position def.definition_name) |> typing ctxt |> fst - in - let body = expr_to_lambda ~subdef:scope_ref scope ctxt def.definition_expr in - (* In case it is a function, wrap it in a EFun*) - let body = - ( match arg_uid with - | None -> body - | Some arg_uid -> - let binding = (arg_uid, Context.get_uid_typ ctxt arg_uid) in - ((EFun ([ binding ], body), Pos.get_position def.definition_expr), TDummy) ) - |> typing ctxt |> fst - in - let x_redef = Lambda.add_default condition body x_redef in - let y_subdef = UidMap.add x_uid x_redef y_subdef in - { - scope_prgm with - scope_sub_defs = UidMap.add subscope_uid y_subdef scope_prgm.scope_sub_defs; - } - | _ -> - Cli.debug_print (Printf.sprintf "Structs are not handled yet.\n%s\n" (Pos.to_string pos)); - assert false +let process_default (ctxt : Context.context) (scope : Uid.t) (def : Lambda.default_term) + (precond : Lambda.term option) (just : Ast.expression Pos.marked option) + (body : Ast.expression Pos.marked) (subdef : Uid.t option) : Lambda.default_term = + let just = + match just with + | Some cond -> + let cond = expr_to_lambda ?subdef scope ctxt cond |> typing ctxt in + if Lambda.get_typ cond = TBool then Some cond else assert false + | None -> None in - UidMap.add scope scope_prgm prgm + let condition = merge_conditions precond just (Pos.get_position body) |> typing ctxt in + let body = expr_to_lambda ?subdef scope ctxt body |> typing ctxt in + Lambda.add_default condition body def + +(* Process a definition *) +let process_def (_precond : Lambda.term option) (_scope : Uid.t) (_ctxt : Context.context) + (_prgm : Scope.program) (_def : Ast.definition) : Scope.program = + assert false (** Process a rule from the surface language *) let process_rule (precond : Lambda.term option) (scope : Uid.t) (ctxt : Context.context) @@ -260,8 +212,8 @@ let process_scope_use (ctxt : Context.context) (prgm : Scope.program) (use : Ast match use.scope_use_condition with | Some expr -> let untyped_term = expr_to_lambda scope_uid ctxt expr in - let term, typ = typing ctxt untyped_term in - if typ = TBool then Some term else assert false + let term = typing ctxt untyped_term in + if Lambda.get_typ term = TBool then Some term else assert false | None -> None in List.fold_left (process_scope_use_item cond scope_uid ctxt) prgm use.scope_use_items diff --git a/src/catala/translation/interpreter.ml b/src/catala/translation/interpreter.ml index 2e193188..10d00db2 100644 --- a/src/catala/translation/interpreter.ml +++ b/src/catala/translation/interpreter.ml @@ -26,20 +26,29 @@ type exec_context = Lambda.untyped_term UidMap.t let empty_exec_ctxt = UidMap.empty -let rec substitute (var : uid) (value : Lambda.untyped_term) (term : Lambda.term) : Lambda.term = - let (term', pos), typ = term in - let subst = substitute var value in - let subst_term = - match term' with - | EVar uid -> if var = uid then value else term' - | EFun (bindings, body) -> EFun (bindings, subst body) - | EApp (f, args) -> EApp (subst f, List.map subst args) - | EIfThenElse (t_if, t_then, t_else) -> EIfThenElse (subst t_if, subst t_then, subst t_else) - | EInt _ | EBool _ | EDec _ | EOp _ -> term' - in - ((subst_term, pos), typ) +let raise_default_conflict (id : string Pos.marked) (true_pos : Pos.t list) (false_pos : Pos.t list) + = + if List.length true_pos = 0 then + let justifications : (string option * Pos.t) list = + List.map (fun pos -> (Some "This justification is false:", pos)) false_pos + in + Errors.raise_multispanned_error + (Printf.sprintf "Default logic error for variable %s: no justification is true." + (Pos.unmark id)) + ( ( Some (Printf.sprintf "The error concerns this variable %s" (Pos.unmark id)), + Pos.get_position id ) + :: justifications ) + else + let justifications : (string option * Pos.t) list = + List.map (fun pos -> (Some "This justification is true:", pos)) true_pos + in + Errors.raise_multispanned_error + "Default logic conflict, multiple justifications are true but are not related by a precedence" + ( ( Some (Printf.sprintf "The conflict concerns this variable %s" (Pos.unmark id)), + Pos.get_position id ) + :: justifications ) -let rec eval_term (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term = +let rec eval_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term = let (term, pos), typ = term in let evaled_term = match term with @@ -54,18 +63,20 @@ let rec eval_term (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term assert false ) | EApp (f, args) -> ( (* First evaluate and match the function body *) - let f = f |> eval_term exec_ctxt |> Lambda.untype in + let f = f |> eval_term top_uid exec_ctxt |> Lambda.untype in match f with | EFun (bindings, body) -> - let body = + let exec_ctxt = List.fold_left2 - (fun body arg (uid, _) -> - substitute uid (arg |> eval_term exec_ctxt |> Lambda.untype) body) - body args bindings + (fun ctxt arg (uid, _) -> + UidMap.add uid (arg |> eval_term top_uid exec_ctxt |> Lambda.untype) ctxt) + exec_ctxt args bindings in - eval_term exec_ctxt body |> Lambda.untype + eval_term top_uid exec_ctxt body |> Lambda.untype | EOp op -> ( - let args = List.map (fun arg -> arg |> eval_term exec_ctxt |> Lambda.untype) args in + let args = + List.map (fun arg -> arg |> eval_term top_uid exec_ctxt |> Lambda.untype) args + in match op with | Binop binop -> ( match binop with @@ -105,21 +116,29 @@ let rec eval_term (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term | Unop Not -> ( match args with [ EBool b ] -> EBool (not b) | _ -> assert false ) ) | _ -> assert false ) | EIfThenElse (t_if, t_then, t_else) -> - ( match eval_term exec_ctxt t_if |> Lambda.untype with - | EBool b -> if b then eval_term exec_ctxt t_then else eval_term exec_ctxt t_else + ( match eval_term top_uid exec_ctxt t_if |> Lambda.untype with + | EBool b -> + if b then eval_term top_uid exec_ctxt t_then else eval_term top_uid exec_ctxt t_else | _ -> assert false ) |> Lambda.untype + | EDefault t -> ( + match eval_default_term top_uid exec_ctxt t with + | Ok value -> value |> Lambda.untype + | Error (true_pos, false_pos) -> + raise_default_conflict (Uid.get_ident top_uid, Uid.get_pos top_uid) true_pos false_pos ) in ((evaled_term, pos), typ) (* Evaluates a default term : see the formalization for an insight about this operation *) -let eval_default_term (exec_ctxt : exec_context) (term : Lambda.default_term) : +and eval_default_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda.default_term) : (Lambda.term, Pos.t list * Pos.t list) result = (* First filter out the term which justification are false *) let candidates : 'a IntMap.t = IntMap.filter (fun _ (cond, _) -> - match eval_term exec_ctxt cond |> Lambda.untype with EBool b -> b | _ -> assert false) + match eval_term top_uid exec_ctxt cond |> Lambda.untype with + | EBool b -> b + | _ -> assert false) term.defaults in (* Now filter out the terms that have a predecessor which justification is true *) @@ -133,7 +152,7 @@ let eval_default_term (exec_ctxt : exec_context) (term : Lambda.default_term) : match ISet.elements chosen_one with | [ x ] -> let _, cons = IntMap.find x term.defaults in - Ok (eval_term exec_ctxt cons) + Ok (eval_term top_uid exec_ctxt cons) | xs -> let true_pos = xs |> List.map (fun x -> IntMap.find x term.defaults |> fst |> Lambda.get_pos) @@ -164,7 +183,7 @@ let build_scope_schedule (ctxt : Context.context) (scope : Scope.scope) : G.t = -> var *) UidMap.iter (fun var_uid def -> - let fv = Lambda.default_term_fv def in + let fv = Lambda.term_fv def in UidSet.iter (fun uid -> if Context.belongs_to ctxt uid scope_uid then @@ -184,7 +203,7 @@ let build_scope_schedule (ctxt : Context.context) (scope : Scope.scope) : G.t = (fun sub_scope_uid defs -> UidMap.iter (fun _ def -> - let fv = Lambda.default_term_fv def in + let fv = Lambda.term_fv def in UidSet.iter (fun var_uid -> (* Process only uid from the current scope (not the subscope) *) @@ -196,41 +215,13 @@ let build_scope_schedule (ctxt : Context.context) (scope : Scope.scope) : G.t = scope.scope_sub_defs; g -let merge_var_redefs (subscope : Scope.scope) (redefs : Scope.definition UidMap.t) : Scope.scope = - { - subscope with - scope_defs = - UidMap.fold - (fun uid new_def sub_defs -> - match UidMap.find_opt uid sub_defs with - | None -> UidMap.add uid new_def sub_defs - | Some old_def -> - let def = Lambda.merge_default_terms old_def new_def in - UidMap.add uid def sub_defs) - redefs subscope.scope_defs; - } +let merge_var_redefs (_subscope : Scope.scope) (_redefs : Scope.definition UidMap.t) : Scope.scope = + assert false -let raise_default_conflict (id : string Pos.marked) (true_pos : Pos.t list) (false_pos : Pos.t list) - = - if List.length true_pos = 0 then - let justifications : (string option * Pos.t) list = - List.map (fun pos -> (Some "This justification is false:", pos)) false_pos - in - Errors.raise_multispanned_error - (Printf.sprintf "Default logic error for variable %s: no justification is true." - (Pos.unmark id)) - ( ( Some (Printf.sprintf "The error concerns this variable %s" (Pos.unmark id)), - Pos.get_position id ) - :: justifications ) - else - let justifications : (string option * Pos.t) list = - List.map (fun pos -> (Some "This justification is true:", pos)) true_pos - in - Errors.raise_multispanned_error - "Default logic conflict, multiple justifications are true but are not related by a precedence" - ( ( Some (Printf.sprintf "The conflict concerns this variable %s" (Pos.unmark id)), - Pos.get_position id ) - :: justifications ) +(*{ subscope with scope_defs = UidMap.fold (fun uid new_def sub_defs -> match UidMap.find_opt uid + sub_defs with | None -> UidMap.add uid new_def sub_defs | Some old_def -> let def = + Lambda.merge_default_terms old_def new_def in UidMap.add uid def sub_defs) redefs + subscope.scope_defs; }*) let rec execute_scope ?(exec_context = empty_exec_ctxt) (ctxt : Context.context) (prgm : Scope.program) (scope_prgm : Scope.scope) : exec_context = @@ -246,11 +237,8 @@ let rec execute_scope ?(exec_context = empty_exec_ctxt) (ctxt : Context.context) match Context.get_uid_sort ctxt uid with | IdScopeVar _ -> ( match UidMap.find_opt uid scope_prgm.scope_defs with - | Some def -> ( - match eval_default_term exec_context def with - | Ok value -> UidMap.add uid (Lambda.untype value) exec_context - | Error (true_pos, false_pos) -> - raise_default_conflict (Uid.get_ident uid, Uid.get_pos uid) true_pos false_pos ) + | Some def -> + UidMap.add uid (eval_term uid exec_context def |> Lambda.untype) exec_context | None -> Cli.error_print (Printf.sprintf "Variable %s is undefined in scope %s\n\n%s\n\n%s"