diff --git a/compiler/dcalc/interpreter.ml b/compiler/dcalc/interpreter.ml index cc819f11..e918adaa 100644 --- a/compiler/dcalc/interpreter.ml +++ b/compiler/dcalc/interpreter.ml @@ -363,7 +363,6 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark the term was well-typed" (Pos.get_position e1)) | EDefault (exceptions, just, cons) -> ( - let exceptions_orig = exceptions in let exceptions = List.map (evaluate_expr ctx) exceptions in let empty_count = List.length (List.filter is_empty_error exceptions) in match List.length exceptions - empty_count with @@ -381,12 +380,12 @@ and evaluate_expr (ctx : Ast.decl_ctx) (e : A.expr Pos.marked) : A.expr Pos.mark | 1 -> List.find (fun sub -> not (is_empty_error sub)) exceptions | _ -> Errors.raise_multispanned_error - "There is a conflict between multiple exceptions for assigning the same variable." + "There is a conflict between multiple validd consequences for assigning the same \ + variable." (List.map - (fun (_, except) -> (Some "This justification is true:", Pos.get_position except)) - (List.filter - (fun (sub, _) -> not (is_empty_error sub)) - (List.map2 (fun x y -> (x, y)) exceptions exceptions_orig)))) + (fun except -> + (Some "This consequence has a valid justification:", Pos.get_position except)) + (List.filter (fun sub -> not (is_empty_error sub)) exceptions))) | EIfThenElse (cond, et, ef) -> ( match Pos.unmark (evaluate_expr ctx cond) with | ELit (LBool true) -> evaluate_expr ctx et diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 98bb54f3..9e10d207 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -240,11 +240,25 @@ let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeD (Ast.RuleSet.to_seq rule_set2)))) all_rule_sets_pointed_to_by_exceptions) all_rule_sets_pointed_to_by_exceptions; + (* Then we add the exception graph vertices by taking all those sets of rules pointed to by + exceptions, and adding the remaining rules not pointed as separate singleton set vertices *) let g = List.fold_left (fun g rule_set -> ExceptionsDependencies.add_vertex g rule_set) ExceptionsDependencies.empty all_rule_sets_pointed_to_by_exceptions in + let g = + Ast.RuleMap.fold + (fun (rule_name : Ast.RuleName.t) _ g -> + if + List.exists + (fun rule_set_pointed_to_by_exceptions -> + Ast.RuleSet.mem rule_name rule_set_pointed_to_by_exceptions) + all_rule_sets_pointed_to_by_exceptions + then g + else ExceptionsDependencies.add_vertex g (Ast.RuleSet.singleton rule_name)) + def g + in (* then we add the edges *) let g = Ast.RuleMap.fold @@ -252,7 +266,8 @@ let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeD (* Right now, exceptions can only consist of one rule, we may want to relax that constraint later in the development of Catala. *) let exception_to_ruleset, pos = rule.Ast.rule_exception_to_rules in - if ExceptionsDependencies.mem_vertex g exception_to_ruleset then + if Ast.RuleSet.is_empty exception_to_ruleset then g (* we don't add an edge*) + else if ExceptionsDependencies.mem_vertex g exception_to_ruleset then if exception_to_ruleset = Ast.RuleSet.singleton rule_name then Errors.raise_spanned_error "Cannot define rule as an exception to itself" pos else diff --git a/compiler/desugared/desugared_to_scope.ml b/compiler/desugared/desugared_to_scope.ml index bce0234b..44cae5e9 100644 --- a/compiler/desugared/desugared_to_scope.ml +++ b/compiler/desugared/desugared_to_scope.ml @@ -18,12 +18,30 @@ open Utils (** {1 Rule tree construction} *) -type rule_tree = Leaf of Ast.rule | Node of rule_tree list * Ast.rule +(** Intermediate representation for the exception tree of rules for a particular scope definition. *) +type rule_tree = + | Leaf of Ast.rule list (** Rules defining a base case piecewise. List is non-empty. *) + | Node of rule_tree list * Ast.rule list + (** A list of exceptions to a non-empty list of rules defining a base case piecewise. *) (** Transforms a flat list of rules into a tree, taking into account the priorities declared between rules *) let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) : rule_tree list = let exc_graph = Dependency.build_exceptions_graph def def_info in + Cli.debug_print + (Format.asprintf "For definition %a, the exception vertices are: %a" Ast.ScopeDef.format_t + def_info + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "; ") + (fun fmt ruleset -> + Format.fprintf fmt "[%a]" + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "; ") + (fun fmt (rule : Ast.RuleName.t) -> + Format.fprintf fmt "%s" + (Pos.to_string_short (Pos.get_position (Ast.RuleName.get_info rule))))) + (List.of_seq (Ast.RuleSet.to_seq ruleset)))) + (Dependency.ExceptionsDependencies.fold_vertex (fun v acc -> v :: acc) exc_graph [])); Dependency.check_for_exception_cycle exc_graph; (* we start by the base cases: they are the vertices which have no successors *) let base_cases = @@ -33,11 +51,14 @@ let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) : else base_cases) exc_graph [] in - let rec build_tree (base_case : Ast.RuleName.t) : rule_tree = - let exceptions = Dependency.ExceptionsDependencies.pred exc_graph base_case in + let rec build_tree (base_cases : Ast.RuleSet.t) : rule_tree = + let exceptions = Dependency.ExceptionsDependencies.pred exc_graph base_cases in + let base_case_as_rule_list = + List.map (fun r -> Ast.RuleMap.find r def) (List.of_seq (Ast.RuleSet.to_seq base_cases)) + in match exceptions with - | [] -> Leaf (Ast.RuleMap.find base_case def) - | _ -> Node (List.map build_tree exceptions, Ast.RuleMap.find base_case def) + | [] -> Leaf base_case_as_rule_list + | _ -> Node (List.map build_tree exceptions, base_case_as_rule_list) in List.map build_tree base_cases @@ -47,14 +68,14 @@ let def_map_to_tree (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) : let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t) (is_func : Scopelang.Ast.Var.t option) (tree : rule_tree) : Scopelang.Ast.expr Pos.marked Bindlib.box = - let exceptions, rule = + let exceptions, base_rules = match tree with Leaf r -> ([], r) | Node (exceptions, r) -> (exceptions, r) in (* because each rule has its own variable parameter and we want to convert the whole rule tree into a function, we need to perform some alpha-renaming of all the expressions *) - let substitute_parameter (e : Scopelang.Ast.expr Pos.marked Bindlib.box) : + let substitute_parameter (e : Scopelang.Ast.expr Pos.marked Bindlib.box) (rule : Ast.rule) : Scopelang.Ast.expr Pos.marked Bindlib.box = - match (is_func, rule.rule_parameter) with + match (is_func, rule.Ast.rule_parameter) with | Some new_param, Some (old_param, _) -> let binder = Bindlib.bind_var old_param e in Bindlib.box_apply2 @@ -64,18 +85,39 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t) | _ -> assert false (* should not happen *) in - let just = substitute_parameter rule.Ast.rule_just in - let cons = substitute_parameter rule.Ast.rule_cons in + let base_just_list = + List.map (fun rule -> substitute_parameter rule.Ast.rule_just rule) base_rules + in + let base_cons_list = + List.map (fun rule -> substitute_parameter rule.Ast.rule_cons rule) base_rules + in + let default_containing_base_cases = + Bindlib.box_apply2 + (fun base_just_list base_cons_list -> + ( Scopelang.Ast.EDefault + ( List.map2 + (fun base_just base_cons -> + (Scopelang.Ast.EDefault ([], base_just, base_cons), Pos.get_position base_just)) + base_just_list base_cons_list, + (Scopelang.Ast.ELit (Dcalc.Ast.LBool false), def_pos), + (Scopelang.Ast.ELit Dcalc.Ast.LEmptyError, def_pos) ), + def_pos )) + (Bindlib.box_list base_just_list) (Bindlib.box_list base_cons_list) + in let exceptions = Bindlib.box_list (List.map (rule_tree_to_expr ~toplevel:false def_pos is_func) exceptions) in let default = - Bindlib.box_apply3 - (fun exceptions just cons -> - (Scopelang.Ast.EDefault (exceptions, just, cons), Pos.get_position just)) - exceptions just cons + Bindlib.box_apply2 + (fun exceptions default_containing_base_cases -> + ( Scopelang.Ast.EDefault + ( exceptions, + (Scopelang.Ast.ELit (Dcalc.Ast.LBool true), def_pos), + default_containing_base_cases ), + def_pos )) + exceptions default_containing_base_cases in - match (is_func, rule.rule_parameter) with + match (is_func, (List.hd base_rules).Ast.rule_parameter) with | None, None -> default | Some new_param, Some (_, typ) -> if toplevel then @@ -137,8 +179,8 @@ let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t) (match top_list with | [] -> (* In this case, there are no rules to define the expression *) - Leaf top_value - | _ -> Node (top_list, top_value))) + Leaf [ top_value ] + | _ -> Node (top_list, [ top_value ]))) (** Translates a scope *) let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =