mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Refactoring done except Desugared_to_scope.def_map_to_tree [skip ci]
This commit is contained in:
parent
f6825668dd
commit
9733f39653
@ -77,7 +77,7 @@ type rule = {
|
||||
rule_just : Scopelang.Ast.expr Pos.marked Bindlib.box;
|
||||
rule_cons : Scopelang.Ast.expr Pos.marked Bindlib.box;
|
||||
rule_parameter : (Scopelang.Ast.Var.t * Scopelang.Ast.typ Pos.marked) option;
|
||||
rule_exception_to_rules : Pos.t RuleMap.t;
|
||||
rule_exception_to_rules : RuleSet.t Pos.marked;
|
||||
}
|
||||
|
||||
let empty_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule =
|
||||
@ -88,7 +88,7 @@ let empty_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked opti
|
||||
(match have_parameter with
|
||||
| Some typ -> Some (Scopelang.Ast.Var.make ("dummy", pos), typ)
|
||||
| None -> None);
|
||||
rule_exception_to_rules = RuleMap.empty;
|
||||
rule_exception_to_rules = (RuleSet.empty, pos);
|
||||
rule_id = RuleName.fresh ("empty", pos);
|
||||
}
|
||||
|
||||
@ -100,7 +100,7 @@ let always_false_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.mark
|
||||
(match have_parameter with
|
||||
| Some typ -> Some (Scopelang.Ast.Var.make ("dummy", pos), typ)
|
||||
| None -> None);
|
||||
rule_exception_to_rules = RuleMap.empty;
|
||||
rule_exception_to_rules = (RuleSet.empty, pos);
|
||||
rule_id = RuleName.fresh ("always_false", pos);
|
||||
}
|
||||
|
||||
|
@ -58,8 +58,7 @@ type rule = {
|
||||
rule_just : Scopelang.Ast.expr Pos.marked Bindlib.box;
|
||||
rule_cons : Scopelang.Ast.expr Pos.marked Bindlib.box;
|
||||
rule_parameter : (Scopelang.Ast.Var.t * Scopelang.Ast.typ Pos.marked) option;
|
||||
rule_exception_to_rules : Pos.t RuleMap.t;
|
||||
(** To each parent exception rule is attached the position of the exception label*)
|
||||
rule_exception_to_rules : RuleSet.t Pos.marked;
|
||||
}
|
||||
|
||||
val empty_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule
|
||||
|
@ -128,7 +128,8 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
|
||||
in
|
||||
let g =
|
||||
Ast.ScopeDefMap.fold
|
||||
(fun def_key (def, _, _) g ->
|
||||
(fun def_key scope_def g ->
|
||||
let def = scope_def.Ast.scope_def_rules in
|
||||
let fv = Ast.free_variables def in
|
||||
Ast.ScopeDefMap.fold
|
||||
(fun fv_def fv_def_pos g ->
|
||||
@ -186,7 +187,9 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
|
||||
(** {2 Graph declaration} *)
|
||||
|
||||
module ExceptionVertex = struct
|
||||
include Ast.RuleName
|
||||
include Ast.RuleSet
|
||||
|
||||
let hash (x : t) : int = Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0
|
||||
|
||||
let equal x y = compare x y = 0
|
||||
end
|
||||
@ -202,32 +205,69 @@ module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies)
|
||||
|
||||
let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeDef.t) :
|
||||
ExceptionsDependencies.t =
|
||||
(* first we add the vertices *)
|
||||
let g =
|
||||
(* first we collect all the rule sets referred by exceptions *)
|
||||
let all_rule_sets_pointed_to_by_exceptions : Ast.RuleSet.t list =
|
||||
Ast.RuleMap.fold
|
||||
(fun rule_name _ g -> ExceptionsDependencies.add_vertex g rule_name)
|
||||
def ExceptionsDependencies.empty
|
||||
(fun _rule_name rule acc ->
|
||||
if Ast.RuleSet.is_empty (Pos.unmark rule.Ast.rule_exception_to_rules) then acc
|
||||
else Pos.unmark rule.Ast.rule_exception_to_rules :: acc)
|
||||
def []
|
||||
in
|
||||
(* we make sure these sets are either disjoint or equal ; should be a syntactic invariant since
|
||||
you currently can't assign two labels to a single rule but an extra check is valuable since
|
||||
this is a required invariant for the graph to be sound *)
|
||||
List.iter
|
||||
(fun rule_set1 ->
|
||||
List.iter
|
||||
(fun rule_set2 ->
|
||||
if Ast.RuleSet.equal rule_set1 rule_set2 then ()
|
||||
else if Ast.RuleSet.disjoint rule_set1 rule_set2 then ()
|
||||
else
|
||||
Errors.raise_multispanned_error
|
||||
"Definitions or rules grouped by different labels overlap, whereas these groups \
|
||||
shoule be disjoint"
|
||||
(List.of_seq
|
||||
(Seq.map
|
||||
(fun rule ->
|
||||
( Some "Rule or definition from the first group:",
|
||||
Pos.get_position (Ast.RuleName.get_info rule) ))
|
||||
(Ast.RuleSet.to_seq rule_set1))
|
||||
@ List.of_seq
|
||||
(Seq.map
|
||||
(fun rule ->
|
||||
( Some "Rule or definition from the second group:",
|
||||
Pos.get_position (Ast.RuleName.get_info rule) ))
|
||||
(Ast.RuleSet.to_seq rule_set2))))
|
||||
all_rule_sets_pointed_to_by_exceptions)
|
||||
all_rule_sets_pointed_to_by_exceptions;
|
||||
let g =
|
||||
List.fold_left
|
||||
(fun g rule_set -> ExceptionsDependencies.add_vertex g rule_set)
|
||||
ExceptionsDependencies.empty all_rule_sets_pointed_to_by_exceptions
|
||||
in
|
||||
(* then we add the edges *)
|
||||
let g =
|
||||
Ast.RuleMap.fold
|
||||
(fun rule_name rule g ->
|
||||
match rule.Ast.exception_to_rule with
|
||||
| None -> g
|
||||
| Some (exc_r, pos) ->
|
||||
if ExceptionsDependencies.mem_vertex g exc_r then
|
||||
if exc_r = rule_name then
|
||||
Errors.raise_spanned_error "Cannot define rule as an exception to itself" pos
|
||||
else
|
||||
let edge = ExceptionsDependencies.E.create rule_name pos exc_r in
|
||||
ExceptionsDependencies.add_edge_e g edge
|
||||
else
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf
|
||||
"This rule has been declared as an exception to an incorrect label: this label \
|
||||
is not attached to a definition of \"%a\""
|
||||
Ast.ScopeDef.format_t def_info)
|
||||
pos)
|
||||
(* Right now, exceptions can only consist of one rule, we may want to relax that constraint
|
||||
later in the development of Catala. *)
|
||||
let exception_to_ruleset, pos = rule.Ast.rule_exception_to_rules in
|
||||
if ExceptionsDependencies.mem_vertex g exception_to_ruleset then
|
||||
if exception_to_ruleset = Ast.RuleSet.singleton rule_name then
|
||||
Errors.raise_spanned_error "Cannot define rule as an exception to itself" pos
|
||||
else
|
||||
let edge =
|
||||
ExceptionsDependencies.E.create (Ast.RuleSet.singleton rule_name) pos
|
||||
exception_to_ruleset
|
||||
in
|
||||
ExceptionsDependencies.add_edge_e g edge
|
||||
else
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf
|
||||
"This rule has been declared as an exception to an incorrect label: this label is \
|
||||
not attached to a definition of \"%a\""
|
||||
Ast.ScopeDef.format_t def_info)
|
||||
pos)
|
||||
def g
|
||||
in
|
||||
g
|
||||
@ -242,11 +282,12 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit =
|
||||
(Format.asprintf "Cyclic dependency detected between exceptions!")
|
||||
(List.flatten
|
||||
(List.map
|
||||
(fun (v : Ast.RuleName.t) ->
|
||||
(fun (vs : Ast.RuleSet.t) ->
|
||||
let v = Ast.RuleSet.choose vs in
|
||||
let var_str, var_info =
|
||||
(Format.asprintf "%a" Ast.RuleName.format_t v, Ast.RuleName.get_info v)
|
||||
in
|
||||
let succs = ExceptionsDependencies.succ_e g v in
|
||||
let succs = ExceptionsDependencies.succ_e g vs in
|
||||
let _, edge_pos, _ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in
|
||||
[
|
||||
( Some
|
||||
|
@ -61,7 +61,7 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t
|
||||
|
||||
(** {1 Exceptions dependency graph} *)
|
||||
|
||||
module ExceptionsDependencies : Graph.Sig.P with type V.t = Ast.RuleName.t and type E.label = Edge.t
|
||||
module ExceptionsDependencies : Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = Edge.t
|
||||
|
||||
val build_exceptions_graph : Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t
|
||||
|
||||
|
@ -54,7 +54,7 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
into a function, we need to perform some alpha-renaming of all the expressions *)
|
||||
let substitute_parameter (e : Scopelang.Ast.expr Pos.marked Bindlib.box) :
|
||||
Scopelang.Ast.expr Pos.marked Bindlib.box =
|
||||
match (is_func, rule.parameter) with
|
||||
match (is_func, rule.rule_parameter) with
|
||||
| Some new_param, Some (old_param, _) ->
|
||||
let binder = Bindlib.bind_var old_param e in
|
||||
Bindlib.box_apply2
|
||||
@ -64,8 +64,8 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
| _ -> assert false
|
||||
(* should not happen *)
|
||||
in
|
||||
let just = substitute_parameter rule.Ast.just in
|
||||
let cons = substitute_parameter rule.Ast.cons in
|
||||
let just = substitute_parameter rule.Ast.rule_just in
|
||||
let cons = substitute_parameter rule.Ast.rule_cons in
|
||||
let exceptions =
|
||||
Bindlib.box_list (List.map (rule_tree_to_expr ~toplevel:false def_pos is_func) exceptions)
|
||||
in
|
||||
@ -75,7 +75,7 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
(Scopelang.Ast.EDefault (exceptions, just, cons), Pos.get_position just))
|
||||
exceptions just cons
|
||||
in
|
||||
match (is_func, rule.parameter) with
|
||||
match (is_func, rule.rule_parameter) with
|
||||
| None, None -> default
|
||||
| Some new_param, Some (_, typ) ->
|
||||
if toplevel then
|
||||
@ -98,7 +98,7 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
|
||||
let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
|
||||
(typ : Scopelang.Ast.typ Pos.marked) (is_cond : bool) : Scopelang.Ast.expr Pos.marked =
|
||||
(* Here, we have to transform this list of rules into a default tree. *)
|
||||
let is_func _ (r : Ast.rule) : bool = Option.is_some r.Ast.parameter in
|
||||
let is_func _ (r : Ast.rule) : bool = Option.is_some r.Ast.rule_parameter in
|
||||
let all_rules_func = Ast.RuleMap.for_all is_func def in
|
||||
let all_rules_not_func = Ast.RuleMap.for_all (fun n r -> not (is_func n r)) def in
|
||||
let is_def_func : Scopelang.Ast.typ Pos.marked option =
|
||||
@ -117,12 +117,13 @@ let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
|
||||
"some definitions of the same variable are functions while others aren't"
|
||||
(List.map
|
||||
(fun (_, r) ->
|
||||
(Some "This definition is a function:", Pos.get_position (Bindlib.unbox r.Ast.cons)))
|
||||
( Some "This definition is a function:",
|
||||
Pos.get_position (Bindlib.unbox r.Ast.rule_cons) ))
|
||||
(Ast.RuleMap.bindings (Ast.RuleMap.filter is_func def))
|
||||
@ List.map
|
||||
(fun (_, r) ->
|
||||
( Some "This definition is not a function:",
|
||||
Pos.get_position (Bindlib.unbox r.Ast.cons) ))
|
||||
Pos.get_position (Bindlib.unbox r.Ast.rule_cons) ))
|
||||
(Ast.RuleMap.bindings (Ast.RuleMap.filter (fun n r -> not (is_func n r)) def)))
|
||||
in
|
||||
let top_list = def_map_to_tree def_info def in
|
||||
@ -150,9 +151,10 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
(fun vertex ->
|
||||
match vertex with
|
||||
| Dependency.Vertex.Var (var : Scopelang.Ast.ScopeVar.t) ->
|
||||
let var_def, var_typ, is_cond =
|
||||
Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs
|
||||
in
|
||||
let scope_def = Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs in
|
||||
let var_def = scope_def.scope_def_rules in
|
||||
let var_typ = scope_def.scope_def_typ in
|
||||
let is_cond = scope_def.scope_def_is_condition in
|
||||
let expr_def = translate_def (Ast.ScopeDef.Var var) var_def var_typ is_cond in
|
||||
[
|
||||
Scopelang.Ast.Definition
|
||||
@ -170,7 +172,10 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
in
|
||||
let sub_scope_vars_redefs =
|
||||
Ast.ScopeDefMap.mapi
|
||||
(fun def_key (def, def_typ, is_cond) ->
|
||||
(fun def_key scope_def ->
|
||||
let def = scope_def.Ast.scope_def_rules in
|
||||
let def_typ = scope_def.scope_def_typ in
|
||||
let is_cond = scope_def.scope_def_is_condition in
|
||||
match def_key with
|
||||
| Ast.ScopeDef.Var _ -> assert false (* should not happen *)
|
||||
| Ast.ScopeDef.SubScopeVar (_, sub_scope_var) ->
|
||||
@ -211,7 +216,7 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
|
||||
let scope_sig =
|
||||
Scopelang.Ast.ScopeVarSet.fold
|
||||
(fun var acc ->
|
||||
let _, typ, _ = Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs in
|
||||
let typ = (Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs).scope_def_typ in
|
||||
Scopelang.Ast.ScopeVarMap.add var typ acc)
|
||||
scope.scope_vars Scopelang.Ast.ScopeVarMap.empty
|
||||
in
|
||||
|
@ -881,8 +881,9 @@ let process_default (ctxt : Name_resolution.context) (scope : Scopelang.Ast.Scop
|
||||
(def_key : Desugared.Ast.ScopeDef.t Pos.marked) (rule_id : Desugared.Ast.RuleName.t)
|
||||
(param_uid : Scopelang.Ast.Var.t Pos.marked option)
|
||||
(precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
|
||||
(exception_to_rules : Pos.t Desugared.Ast.RuleMap.t) (just : Ast.expression Pos.marked option)
|
||||
(cons : Ast.expression Pos.marked) : Desugared.Ast.rule =
|
||||
(exception_to_rules : Desugared.Ast.RuleSet.t Pos.marked)
|
||||
(just : Ast.expression Pos.marked option) (cons : Ast.expression Pos.marked) :
|
||||
Desugared.Ast.rule =
|
||||
let just = match just with Some just -> Some (translate_expr scope ctxt just) | None -> None in
|
||||
let just = merge_conditions precond just (Pos.get_position def_key) in
|
||||
let cons = translate_expr scope ctxt cons in
|
||||
@ -940,22 +941,19 @@ let process_def (precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
|
||||
let rule_name = def.definition_id in
|
||||
let parent_rules =
|
||||
match def.Ast.definition_exception_to with
|
||||
| NotAnException -> Desugared.Ast.RuleMap.empty
|
||||
| NotAnException -> (Desugared.Ast.RuleSet.empty, Pos.get_position def.Ast.definition_name)
|
||||
| UnlabeledException -> (
|
||||
match scope_def_ctxt.default_exception_rulename with
|
||||
(* This should have been caught previously by check_unlabeled_exception *)
|
||||
| None | Some (Name_resolution.Ambiguous _) -> assert false
|
||||
| Some (Name_resolution.Unique name) ->
|
||||
Desugared.Ast.RuleMap.singleton name (Pos.get_position def.Ast.definition_name))
|
||||
(Desugared.Ast.RuleSet.singleton name, Pos.get_position def.Ast.definition_name))
|
||||
| ExceptionToLabel label -> (
|
||||
try
|
||||
Desugared.Ast.RuleSet.fold
|
||||
(fun parent_rule (acc : Pos.t Desugared.Ast.RuleMap.t) ->
|
||||
Desugared.Ast.RuleMap.add parent_rule (Pos.get_position label) acc)
|
||||
(Desugared.Ast.LabelMap.find
|
||||
(Desugared.Ast.IdentMap.find (Pos.unmark label) scope_def_ctxt.label_idmap)
|
||||
scope_def.scope_def_label_groups)
|
||||
Desugared.Ast.RuleMap.empty
|
||||
( Desugared.Ast.LabelMap.find
|
||||
(Desugared.Ast.IdentMap.find (Pos.unmark label) scope_def_ctxt.label_idmap)
|
||||
scope_def.scope_def_label_groups,
|
||||
Pos.get_position def.Ast.definition_name )
|
||||
with Not_found ->
|
||||
Errors.raise_spanned_error
|
||||
(Format.asprintf "Unknown label for the scope variable %a: \"%s\""
|
||||
|
Loading…
Reference in New Issue
Block a user