From 22062434e4901e3213eeefaee88a1dfb2f432d40 Mon Sep 17 00:00:00 2001 From: Nicolas Chataing Date: Fri, 7 Aug 2020 15:56:32 +0200 Subject: [PATCH 1/3] Add EDefault to terms, and replace default_term by term in the scope language --- src/catala/debug.ml | 7 +- src/catala/ir/lambda.ml | 44 +++---- src/catala/ir/scope.ml | 2 +- src/catala/translation/firstpass.ml | 158 +++++++++----------------- src/catala/translation/interpreter.ml | 120 +++++++++---------- 5 files changed, 131 insertions(+), 200 deletions(-) diff --git a/src/catala/debug.ml b/src/catala/debug.ml index 52b079e9..48b401db 100644 --- a/src/catala/debug.ml +++ b/src/catala/debug.ml @@ -85,9 +85,9 @@ let rec print_term (((t, _), _) : Lambda.term) : string = | EBool b -> if b then "true" else "false" | EDec (i, f) -> Printf.sprintf "%d.%d" i f | EOp op -> print_op op + | EDefault t -> print_default_term t -(** Print default term *) -let print_default_term (term : Lambda.default_term) : string = +and print_default_term (term : Lambda.default_term) : string = term.defaults |> Lambda.IntMap.bindings |> List.map (fun (_, (cond, body)) -> Printf.sprintf "\t%s => %s" (print_term cond) (print_term body)) @@ -97,8 +97,7 @@ let print_default_term (term : Lambda.default_term) : string = let print_scope (scope : Scope.scope) : string = let print_defs (defs : Scope.definition UidMap.t) : string = defs |> UidMap.bindings - |> List.map (fun (uid, term) -> - Printf.sprintf "%s:\n%s" (Uid.get_ident uid) (print_default_term term)) + |> List.map (fun (uid, term) -> Printf.sprintf "%s:\n%s" (Uid.get_ident uid) (print_term term)) |> String.concat "" in "___Variables Definition___\n" ^ print_defs scope.scope_defs ^ "___Subscope (Re)definition___\n" diff --git a/src/catala/ir/lambda.ml b/src/catala/ir/lambda.ml index 677e7eff..3d40ee8c 100644 --- a/src/catala/ir/lambda.ml +++ b/src/catala/ir/lambda.ml @@ -14,6 +14,7 @@ module UidMap = Uid.UidMap module UidSet = Uid.UidSet +module IntMap = Map.Make (Int) (* TDummy means the term is not typed *) type typ = TBool | TInt | TArrow of typ * typ | TDummy @@ -41,6 +42,15 @@ and untyped_term = | EBool of bool | EDec of int * int | EOp of op + | EDefault of default_term + +(* (x,y) in ordering means that default x has precedence over default y : if both are true then x + would be choser over y *) +and default_term = { + defaults : (term * term) IntMap.t; + ordering : (int * int) list; + nb_defaults : int; +} let untype (((term, _), _) : term) : untyped_term = term @@ -48,25 +58,9 @@ let get_pos (((_, pos), _) : term) : Pos.t = pos let get_typ ((_, typ) : term) : typ = typ -(* Default terms *) - -module IntMap = Map.Make (Int) - -type justification = term - -type consequence = term - -(* (x,y) in ordering means that default x has precedence over default y : if both are true then x - would be choser over y *) -type default_term = { - defaults : (justification * consequence) IntMap.t; - ordering : (int * int) list; - nb_defaults : int; -} - let empty_default_term : default_term = { defaults = IntMap.empty; ordering = []; nb_defaults = 0 } -let add_default (just : justification) (cons : consequence) (term : default_term) = +let add_default (just : term) (cons : term) (term : default_term) = { term with defaults = IntMap.add term.nb_defaults (just, cons) term.defaults; @@ -99,24 +93,22 @@ let merge_default_terms (lo_term : default_term) (hi_term : default_term) : defa { defaults; ordering = prec @ prec'; nb_defaults = n + n' } (** Returns the free variables of a term *) -let rec term_free_vars (term : term) : UidSet.t = +let rec term_fv (term : term) : UidSet.t = match untype term with | EVar uid -> UidSet.singleton uid | EFun (bindings, body) -> - let body_fv = term_free_vars body in + let body_fv = term_fv body in let bindings = bindings |> List.map (fun (x, _) -> x) |> UidSet.of_list in UidSet.diff body_fv bindings - | EApp (f, args) -> - List.fold_left (fun fv arg -> UidSet.union fv (term_free_vars arg)) (term_free_vars f) args + | EApp (f, args) -> List.fold_left (fun fv arg -> UidSet.union fv (term_fv arg)) (term_fv f) args | EIfThenElse (t_if, t_then, t_else) -> - UidSet.union (term_free_vars t_if) - (UidSet.union (term_free_vars t_then) (term_free_vars t_else)) + UidSet.union (term_fv t_if) (UidSet.union (term_fv t_then) (term_fv t_else)) + | EDefault default -> default_term_fv default | _ -> UidSet.empty -(** Returns the free variable of a default term *) -let default_term_fv (term : default_term) : UidSet.t = +and default_term_fv (term : default_term) : UidSet.t = IntMap.fold (fun _ (cond, body) -> - let fv = UidSet.union (term_free_vars cond) (term_free_vars body) in + let fv = UidSet.union (term_fv cond) (term_fv body) in UidSet.union fv) term.defaults UidSet.empty diff --git a/src/catala/ir/scope.ml b/src/catala/ir/scope.ml index 7e50a158..5bb57a79 100644 --- a/src/catala/ir/scope.ml +++ b/src/catala/ir/scope.ml @@ -17,7 +17,7 @@ module UidMap = Uid.UidMap (* Scopes *) type binder = string Pos.marked -type definition = Lambda.default_term +type definition = Lambda.term type assertion = Lambda.term diff --git a/src/catala/translation/firstpass.ml b/src/catala/translation/firstpass.ml index ffa96e89..e6ac9f28 100644 --- a/src/catala/translation/firstpass.ml +++ b/src/catala/translation/firstpass.ml @@ -68,12 +68,11 @@ let rec expr_to_lambda ?(subdef : Uid.t option) (scope : Uid.t) (ctxt : Context. | _ -> assert false (** Checks that a term is well typed and annotate it *) -let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.term * Lambda.typ = +let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.term = match t with | EVar uid -> let typ = Context.get_uid_typ ctxt uid in - let term = ((EVar uid, pos), typ) in - (term, typ) + ((EVar uid, pos), typ) | EFun (bindings, body) -> (* Note that given the context formation process, the binder will already be present in the context (since we are working with uids), however it's added there for the sake of clarity *) @@ -85,16 +84,19 @@ let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.t ctxt.data bindings in - let body, ret_typ = typing { ctxt with data = ctxt_data } body in + let body = typing { ctxt with data = ctxt_data } body in + let ret_typ = Lambda.get_typ body in let rec build_typ = function | [] -> ret_typ | (_, arg_t) :: args -> TArrow (arg_t, build_typ args) in let fun_typ = build_typ bindings in - (((EFun (bindings, body), pos), fun_typ), fun_typ) + ((EFun (bindings, body), pos), fun_typ) | EApp (f, args) -> - let f, f_typ = typing ctxt f in - let args, args_typ = args |> List.map (typing ctxt) |> List.split in + let f = typing ctxt f in + let f_typ = Lambda.get_typ f in + let args = List.map (typing ctxt) args in + let args_typ = List.map Lambda.get_typ args in let rec check_arrow_typ f_typ args_typ = match (f_typ, args_typ) with | typ, [] -> typ @@ -103,16 +105,19 @@ let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.t | _ -> assert false in let ret_typ = check_arrow_typ f_typ args_typ in - (((EApp (f, args), pos), ret_typ), ret_typ) + ((EApp (f, args), pos), ret_typ) | EIfThenElse (t_if, t_then, t_else) -> - let t_if, typ_if = typing ctxt t_if in - let t_then, typ_then = typing ctxt t_then in - let t_else, typ_else = typing ctxt t_else in + let t_if = typing ctxt t_if in + let typ_if = Lambda.get_typ t_if in + let t_then = typing ctxt t_then in + let typ_then = Lambda.get_typ t_then in + let t_else = typing ctxt t_else in + let typ_else = Lambda.get_typ t_else in if typ_if <> TBool then assert false else if typ_then <> typ_else then assert false - else (((EIfThenElse (t_if, t_then, t_else), pos), typ_then), typ_then) - | EInt _ | EDec _ -> (((t, pos), TInt), TInt) - | EBool _ -> (((t, pos), TBool), TBool) + else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then) + | EInt _ | EDec _ -> ((t, pos), TInt) + | EBool _ -> ((t, pos), TBool) | EOp op -> let typ = match op with @@ -124,7 +129,23 @@ let rec typing (ctxt : Context.context) (((t, pos), _) : Lambda.term) : Lambda.t | Unop Minus -> TArrow (TInt, TInt) | Unop Not -> TArrow (TBool, TBool) in - (((t, pos), typ), typ) + ((t, pos), typ) + | EDefault t -> + let defaults = + IntMap.map + (fun (just, cons) -> + let just = typing ctxt just in + if Lambda.get_typ just <> TBool then + let cons = typing ctxt cons in + (just, cons) + else assert false) + t.defaults + in + let typ_cons = IntMap.choose defaults |> snd |> snd |> Lambda.get_typ in + IntMap.iter + (fun _ (_, cons) -> if Lambda.get_typ cons <> typ_cons then assert false else ()) + defaults; + ((EDefault { t with defaults }, pos), typ_cons) (* Translation from the parsed ast to the scope language *) @@ -137,93 +158,24 @@ let merge_conditions (precond : Lambda.term option) (cond : Lambda.term option) | Some cond, None | None, Some cond -> cond | None, None -> ((EBool true, default_pos), TBool) -(* Process a definition *) -let process_def (precond : Lambda.term option) (scope : Uid.t) (ctxt : Context.context) - (prgm : Scope.program) (def : Ast.definition) : Scope.program = - (* We first check either it is a variable or a subvariable *) - let scope_prgm = UidMap.find scope prgm in - let pos = Pos.get_position def.definition_name in - let scope_prgm = - match Pos.unmark def.definition_name with - | [ x ] -> - let x_uid = Context.get_var_uid scope ctxt x in - let x_def = - match UidMap.find_opt x_uid scope_prgm.scope_defs with - | None -> Lambda.empty_default_term - | Some def -> def - in - (* ctxt redefines just the ident lookup for the argument binding (in case x is a function *) - let ctxt, arg_uid = Context.add_binding ctxt scope x_uid def.definition_parameter in - (* Process the condition *) - let cond = - match def.definition_condition with - | Some cond -> - let cond, typ = typing ctxt (expr_to_lambda scope ctxt cond) in - if typ = TBool then Some cond else assert false - | None -> None - in - let condition = - merge_conditions precond cond (Pos.get_position def.definition_name) |> typing ctxt |> fst - in - let body = expr_to_lambda scope ctxt def.definition_expr in - (* In case it is a function, wrap it in a EFun*) - let body = - ( match arg_uid with - | None -> body - | Some arg_uid -> - let binding = (arg_uid, Context.get_uid_typ ctxt arg_uid) in - ((EFun ([ binding ], body), Pos.get_position def.definition_expr), TDummy) ) - |> typing ctxt |> fst - in - let x_def = Lambda.add_default condition body x_def in - { scope_prgm with scope_defs = UidMap.add x_uid x_def scope_prgm.scope_defs } - | [ y; x ] -> - let subscope_uid, scope_ref = Context.get_subscope_uid scope ctxt y in - let x_uid = Context.get_var_uid scope_ref ctxt x in - let y_subdef = - match UidMap.find_opt subscope_uid scope_prgm.scope_sub_defs with - | Some defs -> defs - | None -> UidMap.empty - in - let x_redef = - match UidMap.find_opt x_uid y_subdef with - | None -> Lambda.empty_default_term - | Some redef -> redef - in - (* ctxt redefines just the ident lookup for the argument binding (in case x is a function *) - let ctxt, arg_uid = Context.add_binding ctxt scope x_uid def.definition_parameter in - (* Process cond with the subdef argument*) - let cond = - match def.definition_condition with - | Some cond -> - let cond, typ = expr_to_lambda ~subdef:scope_ref scope ctxt cond |> typing ctxt in - if typ = TBool then Some cond else assert false - | None -> None - in - let condition = - merge_conditions precond cond (Pos.get_position def.definition_name) |> typing ctxt |> fst - in - let body = expr_to_lambda ~subdef:scope_ref scope ctxt def.definition_expr in - (* In case it is a function, wrap it in a EFun*) - let body = - ( match arg_uid with - | None -> body - | Some arg_uid -> - let binding = (arg_uid, Context.get_uid_typ ctxt arg_uid) in - ((EFun ([ binding ], body), Pos.get_position def.definition_expr), TDummy) ) - |> typing ctxt |> fst - in - let x_redef = Lambda.add_default condition body x_redef in - let y_subdef = UidMap.add x_uid x_redef y_subdef in - { - scope_prgm with - scope_sub_defs = UidMap.add subscope_uid y_subdef scope_prgm.scope_sub_defs; - } - | _ -> - Cli.debug_print (Printf.sprintf "Structs are not handled yet.\n%s\n" (Pos.to_string pos)); - assert false +let process_default (ctxt : Context.context) (scope : Uid.t) (def : Lambda.default_term) + (precond : Lambda.term option) (just : Ast.expression Pos.marked option) + (body : Ast.expression Pos.marked) (subdef : Uid.t option) : Lambda.default_term = + let just = + match just with + | Some cond -> + let cond = expr_to_lambda ?subdef scope ctxt cond |> typing ctxt in + if Lambda.get_typ cond = TBool then Some cond else assert false + | None -> None in - UidMap.add scope scope_prgm prgm + let condition = merge_conditions precond just (Pos.get_position body) |> typing ctxt in + let body = expr_to_lambda ?subdef scope ctxt body |> typing ctxt in + Lambda.add_default condition body def + +(* Process a definition *) +let process_def (_precond : Lambda.term option) (_scope : Uid.t) (_ctxt : Context.context) + (_prgm : Scope.program) (_def : Ast.definition) : Scope.program = + assert false (** Process a rule from the surface language *) let process_rule (precond : Lambda.term option) (scope : Uid.t) (ctxt : Context.context) @@ -260,8 +212,8 @@ let process_scope_use (ctxt : Context.context) (prgm : Scope.program) (use : Ast match use.scope_use_condition with | Some expr -> let untyped_term = expr_to_lambda scope_uid ctxt expr in - let term, typ = typing ctxt untyped_term in - if typ = TBool then Some term else assert false + let term = typing ctxt untyped_term in + if Lambda.get_typ term = TBool then Some term else assert false | None -> None in List.fold_left (process_scope_use_item cond scope_uid ctxt) prgm use.scope_use_items diff --git a/src/catala/translation/interpreter.ml b/src/catala/translation/interpreter.ml index 2e193188..10d00db2 100644 --- a/src/catala/translation/interpreter.ml +++ b/src/catala/translation/interpreter.ml @@ -26,20 +26,29 @@ type exec_context = Lambda.untyped_term UidMap.t let empty_exec_ctxt = UidMap.empty -let rec substitute (var : uid) (value : Lambda.untyped_term) (term : Lambda.term) : Lambda.term = - let (term', pos), typ = term in - let subst = substitute var value in - let subst_term = - match term' with - | EVar uid -> if var = uid then value else term' - | EFun (bindings, body) -> EFun (bindings, subst body) - | EApp (f, args) -> EApp (subst f, List.map subst args) - | EIfThenElse (t_if, t_then, t_else) -> EIfThenElse (subst t_if, subst t_then, subst t_else) - | EInt _ | EBool _ | EDec _ | EOp _ -> term' - in - ((subst_term, pos), typ) +let raise_default_conflict (id : string Pos.marked) (true_pos : Pos.t list) (false_pos : Pos.t list) + = + if List.length true_pos = 0 then + let justifications : (string option * Pos.t) list = + List.map (fun pos -> (Some "This justification is false:", pos)) false_pos + in + Errors.raise_multispanned_error + (Printf.sprintf "Default logic error for variable %s: no justification is true." + (Pos.unmark id)) + ( ( Some (Printf.sprintf "The error concerns this variable %s" (Pos.unmark id)), + Pos.get_position id ) + :: justifications ) + else + let justifications : (string option * Pos.t) list = + List.map (fun pos -> (Some "This justification is true:", pos)) true_pos + in + Errors.raise_multispanned_error + "Default logic conflict, multiple justifications are true but are not related by a precedence" + ( ( Some (Printf.sprintf "The conflict concerns this variable %s" (Pos.unmark id)), + Pos.get_position id ) + :: justifications ) -let rec eval_term (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term = +let rec eval_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term = let (term, pos), typ = term in let evaled_term = match term with @@ -54,18 +63,20 @@ let rec eval_term (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term assert false ) | EApp (f, args) -> ( (* First evaluate and match the function body *) - let f = f |> eval_term exec_ctxt |> Lambda.untype in + let f = f |> eval_term top_uid exec_ctxt |> Lambda.untype in match f with | EFun (bindings, body) -> - let body = + let exec_ctxt = List.fold_left2 - (fun body arg (uid, _) -> - substitute uid (arg |> eval_term exec_ctxt |> Lambda.untype) body) - body args bindings + (fun ctxt arg (uid, _) -> + UidMap.add uid (arg |> eval_term top_uid exec_ctxt |> Lambda.untype) ctxt) + exec_ctxt args bindings in - eval_term exec_ctxt body |> Lambda.untype + eval_term top_uid exec_ctxt body |> Lambda.untype | EOp op -> ( - let args = List.map (fun arg -> arg |> eval_term exec_ctxt |> Lambda.untype) args in + let args = + List.map (fun arg -> arg |> eval_term top_uid exec_ctxt |> Lambda.untype) args + in match op with | Binop binop -> ( match binop with @@ -105,21 +116,29 @@ let rec eval_term (exec_ctxt : exec_context) (term : Lambda.term) : Lambda.term | Unop Not -> ( match args with [ EBool b ] -> EBool (not b) | _ -> assert false ) ) | _ -> assert false ) | EIfThenElse (t_if, t_then, t_else) -> - ( match eval_term exec_ctxt t_if |> Lambda.untype with - | EBool b -> if b then eval_term exec_ctxt t_then else eval_term exec_ctxt t_else + ( match eval_term top_uid exec_ctxt t_if |> Lambda.untype with + | EBool b -> + if b then eval_term top_uid exec_ctxt t_then else eval_term top_uid exec_ctxt t_else | _ -> assert false ) |> Lambda.untype + | EDefault t -> ( + match eval_default_term top_uid exec_ctxt t with + | Ok value -> value |> Lambda.untype + | Error (true_pos, false_pos) -> + raise_default_conflict (Uid.get_ident top_uid, Uid.get_pos top_uid) true_pos false_pos ) in ((evaled_term, pos), typ) (* Evaluates a default term : see the formalization for an insight about this operation *) -let eval_default_term (exec_ctxt : exec_context) (term : Lambda.default_term) : +and eval_default_term (top_uid : Uid.t) (exec_ctxt : exec_context) (term : Lambda.default_term) : (Lambda.term, Pos.t list * Pos.t list) result = (* First filter out the term which justification are false *) let candidates : 'a IntMap.t = IntMap.filter (fun _ (cond, _) -> - match eval_term exec_ctxt cond |> Lambda.untype with EBool b -> b | _ -> assert false) + match eval_term top_uid exec_ctxt cond |> Lambda.untype with + | EBool b -> b + | _ -> assert false) term.defaults in (* Now filter out the terms that have a predecessor which justification is true *) @@ -133,7 +152,7 @@ let eval_default_term (exec_ctxt : exec_context) (term : Lambda.default_term) : match ISet.elements chosen_one with | [ x ] -> let _, cons = IntMap.find x term.defaults in - Ok (eval_term exec_ctxt cons) + Ok (eval_term top_uid exec_ctxt cons) | xs -> let true_pos = xs |> List.map (fun x -> IntMap.find x term.defaults |> fst |> Lambda.get_pos) @@ -164,7 +183,7 @@ let build_scope_schedule (ctxt : Context.context) (scope : Scope.scope) : G.t = -> var *) UidMap.iter (fun var_uid def -> - let fv = Lambda.default_term_fv def in + let fv = Lambda.term_fv def in UidSet.iter (fun uid -> if Context.belongs_to ctxt uid scope_uid then @@ -184,7 +203,7 @@ let build_scope_schedule (ctxt : Context.context) (scope : Scope.scope) : G.t = (fun sub_scope_uid defs -> UidMap.iter (fun _ def -> - let fv = Lambda.default_term_fv def in + let fv = Lambda.term_fv def in UidSet.iter (fun var_uid -> (* Process only uid from the current scope (not the subscope) *) @@ -196,41 +215,13 @@ let build_scope_schedule (ctxt : Context.context) (scope : Scope.scope) : G.t = scope.scope_sub_defs; g -let merge_var_redefs (subscope : Scope.scope) (redefs : Scope.definition UidMap.t) : Scope.scope = - { - subscope with - scope_defs = - UidMap.fold - (fun uid new_def sub_defs -> - match UidMap.find_opt uid sub_defs with - | None -> UidMap.add uid new_def sub_defs - | Some old_def -> - let def = Lambda.merge_default_terms old_def new_def in - UidMap.add uid def sub_defs) - redefs subscope.scope_defs; - } +let merge_var_redefs (_subscope : Scope.scope) (_redefs : Scope.definition UidMap.t) : Scope.scope = + assert false -let raise_default_conflict (id : string Pos.marked) (true_pos : Pos.t list) (false_pos : Pos.t list) - = - if List.length true_pos = 0 then - let justifications : (string option * Pos.t) list = - List.map (fun pos -> (Some "This justification is false:", pos)) false_pos - in - Errors.raise_multispanned_error - (Printf.sprintf "Default logic error for variable %s: no justification is true." - (Pos.unmark id)) - ( ( Some (Printf.sprintf "The error concerns this variable %s" (Pos.unmark id)), - Pos.get_position id ) - :: justifications ) - else - let justifications : (string option * Pos.t) list = - List.map (fun pos -> (Some "This justification is true:", pos)) true_pos - in - Errors.raise_multispanned_error - "Default logic conflict, multiple justifications are true but are not related by a precedence" - ( ( Some (Printf.sprintf "The conflict concerns this variable %s" (Pos.unmark id)), - Pos.get_position id ) - :: justifications ) +(*{ subscope with scope_defs = UidMap.fold (fun uid new_def sub_defs -> match UidMap.find_opt uid + sub_defs with | None -> UidMap.add uid new_def sub_defs | Some old_def -> let def = + Lambda.merge_default_terms old_def new_def in UidMap.add uid def sub_defs) redefs + subscope.scope_defs; }*) let rec execute_scope ?(exec_context = empty_exec_ctxt) (ctxt : Context.context) (prgm : Scope.program) (scope_prgm : Scope.scope) : exec_context = @@ -246,11 +237,8 @@ let rec execute_scope ?(exec_context = empty_exec_ctxt) (ctxt : Context.context) match Context.get_uid_sort ctxt uid with | IdScopeVar _ -> ( match UidMap.find_opt uid scope_prgm.scope_defs with - | Some def -> ( - match eval_default_term exec_context def with - | Ok value -> UidMap.add uid (Lambda.untype value) exec_context - | Error (true_pos, false_pos) -> - raise_default_conflict (Uid.get_ident uid, Uid.get_pos uid) true_pos false_pos ) + | Some def -> + UidMap.add uid (eval_term uid exec_context def |> Lambda.untype) exec_context | None -> Cli.error_print (Printf.sprintf "Variable %s is undefined in scope %s\n\n%s\n\n%s" From 78c7b653ecc61ab45c6aae6bc26b1d55a4e043a7 Mon Sep 17 00:00:00 2001 From: Nicolas Chataing Date: Sun, 9 Aug 2020 23:01:42 +0200 Subject: [PATCH 2/3] Correct function handling --- src/catala/ir/lambda.ml | 3 ++ src/catala/ir/scope.ml | 10 ++++ src/catala/translation/firstpass.ml | 84 +++++++++++++++++++++++++++-- tests/test_func/func.catala | 2 +- 4 files changed, 95 insertions(+), 4 deletions(-) diff --git a/src/catala/ir/lambda.ml b/src/catala/ir/lambda.ml index 3d40ee8c..99acf84e 100644 --- a/src/catala/ir/lambda.ml +++ b/src/catala/ir/lambda.ml @@ -58,6 +58,9 @@ let get_pos (((_, pos), _) : term) : Pos.t = pos let get_typ ((_, typ) : term) : typ = typ +let map_untype (f : untyped_term -> untyped_term) (((term, pos), typ) : term) : term = + ((f term, pos), typ) + let empty_default_term : default_term = { defaults = IntMap.empty; ordering = []; nb_defaults = 0 } let add_default (just : term) (cons : term) (term : default_term) = diff --git a/src/catala/ir/scope.ml b/src/catala/ir/scope.ml index 5bb57a79..cffa5adb 100644 --- a/src/catala/ir/scope.ml +++ b/src/catala/ir/scope.ml @@ -19,6 +19,16 @@ type binder = string Pos.marked type definition = Lambda.term +let empty_func_def (bind : Uid.t) (pos : Pos.t) (typ : Lambda.typ) : definition = + match typ with + | TArrow (t_arg, t_ret) -> + let body_term : Lambda.term = ((EDefault Lambda.empty_default_term, pos), t_ret) in + ((EFun ([ (bind, t_arg) ], body_term), pos), typ) + | _ -> assert false + +let empty_var_def (pos : Pos.t) (typ : Lambda.typ) : definition = + ((EDefault Lambda.empty_default_term, pos), typ) + type assertion = Lambda.term type variation_typ = Increasing | Decreasing diff --git a/src/catala/translation/firstpass.ml b/src/catala/translation/firstpass.ml index e6ac9f28..73885f0c 100644 --- a/src/catala/translation/firstpass.ml +++ b/src/catala/translation/firstpass.ml @@ -173,9 +173,87 @@ let process_default (ctxt : Context.context) (scope : Uid.t) (def : Lambda.defau Lambda.add_default condition body def (* Process a definition *) -let process_def (_precond : Lambda.term option) (_scope : Uid.t) (_ctxt : Context.context) - (_prgm : Scope.program) (_def : Ast.definition) : Scope.program = - assert false +let process_def (precond : Lambda.term option) (scope_uid : Uid.t) (ctxt : Context.context) + (prgm : Scope.program) (def : Ast.definition) : Scope.program = + let scope : Scope.scope = UidMap.find scope_uid prgm in + let default_pos = Pos.get_position def.definition_expr in + let is_func = match def.definition_parameter with Some _ -> true | None -> false in + let scope_updated = + match Pos.unmark def.definition_name with + | [ x ] -> + let x_uid = Context.get_var_uid scope_uid ctxt x in + let ctxt, arg_uid = Context.add_binding ctxt scope_uid x_uid def.definition_parameter in + let x_def = + match UidMap.find_opt x_uid scope.scope_defs with + | Some def -> def + | None -> + let typ = Context.get_uid_typ ctxt x_uid in + if is_func then Scope.empty_func_def (Option.get arg_uid) default_pos typ + else Scope.empty_var_def default_pos typ + in + let x_def = + Lambda.map_untype + (fun t -> + match t with + | EDefault default -> + EDefault + (process_default ctxt scope_uid default precond def.definition_condition + def.definition_expr None) + | EFun ([ bind ], term) -> ( + match arg_uid with + | Some bind' -> + if fst bind <> bind' then assert false + else + let term = + Lambda.map_untype + (fun t -> + match t with + | EDefault default -> + EDefault + (process_default ctxt scope_uid default precond + def.definition_condition def.definition_expr None) + | _ -> assert false) + term + in + EFun ([ bind ], term) + | None -> assert false ) + | _ -> assert false) + x_def + in + { scope with scope_defs = UidMap.add x_uid x_def scope.scope_defs } + | [ y; x ] -> + let y_uid, subscope_uid = Context.get_subscope_uid scope_uid ctxt y in + let x_uid = Context.get_var_uid subscope_uid ctxt x in + let ctxt, arg_uid = Context.add_binding ctxt subscope_uid x_uid def.definition_parameter in + let y_redefs = + match UidMap.find_opt y_uid scope.scope_sub_defs with + | Some redefs -> redefs + | None -> UidMap.empty + in + let x_redef = + match UidMap.find_opt x_uid y_redefs with + | Some redef -> redef + | None -> + let typ = Context.get_uid_typ ctxt x_uid in + if is_func then Scope.empty_func_def (Option.get arg_uid) default_pos typ + else Scope.empty_var_def default_pos typ + in + let x_redef = + Lambda.map_untype + (fun t -> + match t with + | EDefault default -> + EDefault + (process_default ctxt scope_uid default precond def.definition_condition + def.definition_expr (Some subscope_uid)) + | _ -> assert false) + x_redef + in + let y_redefs = UidMap.add x_uid x_redef y_redefs in + { scope with scope_sub_defs = UidMap.add y_uid y_redefs scope.scope_sub_defs } + | _ -> assert false + in + UidMap.add scope_uid scope_updated prgm (** Process a rule from the surface language *) let process_rule (precond : Lambda.term option) (scope : Uid.t) (ctxt : Context.context) diff --git a/tests/test_func/func.catala b/tests/test_func/func.catala index d27b2b1e..ddf59044 100644 --- a/tests/test_func/func.catala +++ b/tests/test_func/func.catala @@ -5,7 +5,7 @@ declaration scope S: context b content bool scope S: - def f of x ? (x >= x) |= x + x + def f of x ? x > x |= x + x def f of x ? not b |= x * x def b = false From 6acf49b6feafecf4c468dbba8f9c74f70a118daf Mon Sep 17 00:00:00 2001 From: Nicolas Chataing Date: Sun, 9 Aug 2020 23:50:00 +0200 Subject: [PATCH 3/3] Rewrite merge_var_redefs function --- src/catala/ir/lambda.ml | 4 ++++ src/catala/translation/interpreter.ml | 34 +++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/catala/ir/lambda.ml b/src/catala/ir/lambda.ml index 99acf84e..9b6f912a 100644 --- a/src/catala/ir/lambda.ml +++ b/src/catala/ir/lambda.ml @@ -61,6 +61,10 @@ let get_typ ((_, typ) : term) : typ = typ let map_untype (f : untyped_term -> untyped_term) (((term, pos), typ) : term) : term = ((f term, pos), typ) +let map_untype2 (f : untyped_term -> untyped_term -> untyped_term) (((t1, pos), typ) : term) + (((t2, _), _) : term) : term = + ((f t1 t2, pos), typ) + let empty_default_term : default_term = { defaults = IntMap.empty; ordering = []; nb_defaults = 0 } let add_default (just : term) (cons : term) (term : default_term) = diff --git a/src/catala/translation/interpreter.ml b/src/catala/translation/interpreter.ml index 67b9c30e..f1720dda 100644 --- a/src/catala/translation/interpreter.ml +++ b/src/catala/translation/interpreter.ml @@ -214,8 +214,38 @@ let build_scope_schedule (ctxt : Context.context) (scope : Scope.scope) : G.t = scope.scope_sub_defs; g -let merge_var_redefs (_subscope : Scope.scope) (_redefs : Scope.definition UidMap.t) : Scope.scope = - assert false +let merge_var_redefs (subscope : Scope.scope) (redefs : Scope.definition UidMap.t) : Scope.scope = + let merge_defaults : Lambda.term -> Lambda.term -> Lambda.term = + Lambda.map_untype2 (fun old_t new_t -> + match (old_t, new_t) with + | EDefault old_def, EDefault new_def -> + EDefault (Lambda.merge_default_terms old_def new_def) + | EFun ([ bind ], old_t), EFun (_, new_t) -> + let body = + Lambda.map_untype2 + (fun old_t new_t -> + match (old_t, new_t) with + | EDefault old_def, EDefault new_def -> + EDefault (Lambda.merge_default_terms old_def new_def) + | _ -> assert false) + old_t new_t + in + EFun ([ bind ], body) + | _ -> assert false) + in + + { + subscope with + scope_defs = + UidMap.fold + (fun uid new_def sub_defs -> + match UidMap.find_opt uid sub_defs with + | None -> UidMap.add uid new_def sub_defs + | Some old_def -> + let def = merge_defaults old_def new_def in + UidMap.add uid def sub_defs) + redefs subscope.scope_defs; + } (*{ subscope with scope_defs = UidMap.fold (fun uid new_def sub_defs -> match UidMap.find_opt uid sub_defs with | None -> UidMap.add uid new_def sub_defs | Some old_def -> let def =