Refactoring done except Desugared_to_scope.def_map_to_tree [skip ci]

This commit is contained in:
Denis Merigoux 2022-01-05 09:14:43 +01:00
parent f6825668dd
commit 9733f39653
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
6 changed files with 96 additions and 53 deletions

View File

@ -77,7 +77,7 @@ type rule = {
rule_just : Scopelang.Ast.expr Pos.marked Bindlib.box;
rule_cons : Scopelang.Ast.expr Pos.marked Bindlib.box;
rule_parameter : (Scopelang.Ast.Var.t * Scopelang.Ast.typ Pos.marked) option;
rule_exception_to_rules : Pos.t RuleMap.t;
rule_exception_to_rules : RuleSet.t Pos.marked;
}
let empty_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked option) : rule =
@ -88,7 +88,7 @@ let empty_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.marked opti
(match have_parameter with
| Some typ -> Some (Scopelang.Ast.Var.make ("dummy", pos), typ)
| None -> None);
rule_exception_to_rules = RuleMap.empty;
rule_exception_to_rules = (RuleSet.empty, pos);
rule_id = RuleName.fresh ("empty", pos);
}
@ -100,7 +100,7 @@ let always_false_rule (pos : Pos.t) (have_parameter : Scopelang.Ast.typ Pos.mark
(match have_parameter with
| Some typ -> Some (Scopelang.Ast.Var.make ("dummy", pos), typ)
| None -> None);
rule_exception_to_rules = RuleMap.empty;
rule_exception_to_rules = (RuleSet.empty, pos);
rule_id = RuleName.fresh ("always_false", pos);
}

View File

@ -58,8 +58,7 @@ type rule = {
rule_just : Scopelang.Ast.expr Pos.marked Bindlib.box;
rule_cons : Scopelang.Ast.expr Pos.marked Bindlib.box;
rule_parameter : (Scopelang.Ast.Var.t * Scopelang.Ast.typ Pos.marked) option;
rule_exception_to_rules : Pos.t RuleMap.t;
(** To each parent exception rule is attached the position of the exception label*)
rule_exception_to_rules : RuleSet.t Pos.marked;
}
val empty_rule : Pos.t -> Scopelang.Ast.typ Pos.marked option -> rule

View File

@ -128,7 +128,8 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
in
let g =
Ast.ScopeDefMap.fold
(fun def_key (def, _, _) g ->
(fun def_key scope_def g ->
let def = scope_def.Ast.scope_def_rules in
let fv = Ast.free_variables def in
Ast.ScopeDefMap.fold
(fun fv_def fv_def_pos g ->
@ -186,7 +187,9 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t =
(** {2 Graph declaration} *)
module ExceptionVertex = struct
include Ast.RuleName
include Ast.RuleSet
let hash (x : t) : int = Ast.RuleSet.fold (fun r acc -> Int.logxor (Ast.RuleName.hash r) acc) x 0
let equal x y = compare x y = 0
end
@ -202,32 +205,69 @@ module ExceptionsSCC = Graph.Components.Make (ExceptionsDependencies)
let build_exceptions_graph (def : Ast.rule Ast.RuleMap.t) (def_info : Ast.ScopeDef.t) :
ExceptionsDependencies.t =
(* first we add the vertices *)
let g =
(* first we collect all the rule sets referred by exceptions *)
let all_rule_sets_pointed_to_by_exceptions : Ast.RuleSet.t list =
Ast.RuleMap.fold
(fun rule_name _ g -> ExceptionsDependencies.add_vertex g rule_name)
def ExceptionsDependencies.empty
(fun _rule_name rule acc ->
if Ast.RuleSet.is_empty (Pos.unmark rule.Ast.rule_exception_to_rules) then acc
else Pos.unmark rule.Ast.rule_exception_to_rules :: acc)
def []
in
(* we make sure these sets are either disjoint or equal ; should be a syntactic invariant since
you currently can't assign two labels to a single rule but an extra check is valuable since
this is a required invariant for the graph to be sound *)
List.iter
(fun rule_set1 ->
List.iter
(fun rule_set2 ->
if Ast.RuleSet.equal rule_set1 rule_set2 then ()
else if Ast.RuleSet.disjoint rule_set1 rule_set2 then ()
else
Errors.raise_multispanned_error
"Definitions or rules grouped by different labels overlap, whereas these groups \
shoule be disjoint"
(List.of_seq
(Seq.map
(fun rule ->
( Some "Rule or definition from the first group:",
Pos.get_position (Ast.RuleName.get_info rule) ))
(Ast.RuleSet.to_seq rule_set1))
@ List.of_seq
(Seq.map
(fun rule ->
( Some "Rule or definition from the second group:",
Pos.get_position (Ast.RuleName.get_info rule) ))
(Ast.RuleSet.to_seq rule_set2))))
all_rule_sets_pointed_to_by_exceptions)
all_rule_sets_pointed_to_by_exceptions;
let g =
List.fold_left
(fun g rule_set -> ExceptionsDependencies.add_vertex g rule_set)
ExceptionsDependencies.empty all_rule_sets_pointed_to_by_exceptions
in
(* then we add the edges *)
let g =
Ast.RuleMap.fold
(fun rule_name rule g ->
match rule.Ast.exception_to_rule with
| None -> g
| Some (exc_r, pos) ->
if ExceptionsDependencies.mem_vertex g exc_r then
if exc_r = rule_name then
Errors.raise_spanned_error "Cannot define rule as an exception to itself" pos
else
let edge = ExceptionsDependencies.E.create rule_name pos exc_r in
ExceptionsDependencies.add_edge_e g edge
else
Errors.raise_spanned_error
(Format.asprintf
"This rule has been declared as an exception to an incorrect label: this label \
is not attached to a definition of \"%a\""
Ast.ScopeDef.format_t def_info)
pos)
(* Right now, exceptions can only consist of one rule, we may want to relax that constraint
later in the development of Catala. *)
let exception_to_ruleset, pos = rule.Ast.rule_exception_to_rules in
if ExceptionsDependencies.mem_vertex g exception_to_ruleset then
if exception_to_ruleset = Ast.RuleSet.singleton rule_name then
Errors.raise_spanned_error "Cannot define rule as an exception to itself" pos
else
let edge =
ExceptionsDependencies.E.create (Ast.RuleSet.singleton rule_name) pos
exception_to_ruleset
in
ExceptionsDependencies.add_edge_e g edge
else
Errors.raise_spanned_error
(Format.asprintf
"This rule has been declared as an exception to an incorrect label: this label is \
not attached to a definition of \"%a\""
Ast.ScopeDef.format_t def_info)
pos)
def g
in
g
@ -242,11 +282,12 @@ let check_for_exception_cycle (g : ExceptionsDependencies.t) : unit =
(Format.asprintf "Cyclic dependency detected between exceptions!")
(List.flatten
(List.map
(fun (v : Ast.RuleName.t) ->
(fun (vs : Ast.RuleSet.t) ->
let v = Ast.RuleSet.choose vs in
let var_str, var_info =
(Format.asprintf "%a" Ast.RuleName.format_t v, Ast.RuleName.get_info v)
in
let succs = ExceptionsDependencies.succ_e g v in
let succs = ExceptionsDependencies.succ_e g vs in
let _, edge_pos, _ = List.find (fun (_, _, succ) -> List.mem succ scc) succs in
[
( Some

View File

@ -61,7 +61,7 @@ val build_scope_dependencies : Ast.scope -> ScopeDependencies.t
(** {1 Exceptions dependency graph} *)
module ExceptionsDependencies : Graph.Sig.P with type V.t = Ast.RuleName.t and type E.label = Edge.t
module ExceptionsDependencies : Graph.Sig.P with type V.t = Ast.RuleSet.t and type E.label = Edge.t
val build_exceptions_graph : Ast.rule Ast.RuleMap.t -> Ast.ScopeDef.t -> ExceptionsDependencies.t

View File

@ -54,7 +54,7 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
into a function, we need to perform some alpha-renaming of all the expressions *)
let substitute_parameter (e : Scopelang.Ast.expr Pos.marked Bindlib.box) :
Scopelang.Ast.expr Pos.marked Bindlib.box =
match (is_func, rule.parameter) with
match (is_func, rule.rule_parameter) with
| Some new_param, Some (old_param, _) ->
let binder = Bindlib.bind_var old_param e in
Bindlib.box_apply2
@ -64,8 +64,8 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
| _ -> assert false
(* should not happen *)
in
let just = substitute_parameter rule.Ast.just in
let cons = substitute_parameter rule.Ast.cons in
let just = substitute_parameter rule.Ast.rule_just in
let cons = substitute_parameter rule.Ast.rule_cons in
let exceptions =
Bindlib.box_list (List.map (rule_tree_to_expr ~toplevel:false def_pos is_func) exceptions)
in
@ -75,7 +75,7 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
(Scopelang.Ast.EDefault (exceptions, just, cons), Pos.get_position just))
exceptions just cons
in
match (is_func, rule.parameter) with
match (is_func, rule.rule_parameter) with
| None, None -> default
| Some new_param, Some (_, typ) ->
if toplevel then
@ -98,7 +98,7 @@ let rec rule_tree_to_expr ~(toplevel : bool) (def_pos : Pos.t)
let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
(typ : Scopelang.Ast.typ Pos.marked) (is_cond : bool) : Scopelang.Ast.expr Pos.marked =
(* Here, we have to transform this list of rules into a default tree. *)
let is_func _ (r : Ast.rule) : bool = Option.is_some r.Ast.parameter in
let is_func _ (r : Ast.rule) : bool = Option.is_some r.Ast.rule_parameter in
let all_rules_func = Ast.RuleMap.for_all is_func def in
let all_rules_not_func = Ast.RuleMap.for_all (fun n r -> not (is_func n r)) def in
let is_def_func : Scopelang.Ast.typ Pos.marked option =
@ -117,12 +117,13 @@ let translate_def (def_info : Ast.ScopeDef.t) (def : Ast.rule Ast.RuleMap.t)
"some definitions of the same variable are functions while others aren't"
(List.map
(fun (_, r) ->
(Some "This definition is a function:", Pos.get_position (Bindlib.unbox r.Ast.cons)))
( Some "This definition is a function:",
Pos.get_position (Bindlib.unbox r.Ast.rule_cons) ))
(Ast.RuleMap.bindings (Ast.RuleMap.filter is_func def))
@ List.map
(fun (_, r) ->
( Some "This definition is not a function:",
Pos.get_position (Bindlib.unbox r.Ast.cons) ))
Pos.get_position (Bindlib.unbox r.Ast.rule_cons) ))
(Ast.RuleMap.bindings (Ast.RuleMap.filter (fun n r -> not (is_func n r)) def)))
in
let top_list = def_map_to_tree def_info def in
@ -150,9 +151,10 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
(fun vertex ->
match vertex with
| Dependency.Vertex.Var (var : Scopelang.Ast.ScopeVar.t) ->
let var_def, var_typ, is_cond =
Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs
in
let scope_def = Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs in
let var_def = scope_def.scope_def_rules in
let var_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
let expr_def = translate_def (Ast.ScopeDef.Var var) var_def var_typ is_cond in
[
Scopelang.Ast.Definition
@ -170,7 +172,10 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
in
let sub_scope_vars_redefs =
Ast.ScopeDefMap.mapi
(fun def_key (def, def_typ, is_cond) ->
(fun def_key scope_def ->
let def = scope_def.Ast.scope_def_rules in
let def_typ = scope_def.scope_def_typ in
let is_cond = scope_def.scope_def_is_condition in
match def_key with
| Ast.ScopeDef.Var _ -> assert false (* should not happen *)
| Ast.ScopeDef.SubScopeVar (_, sub_scope_var) ->
@ -211,7 +216,7 @@ let translate_scope (scope : Ast.scope) : Scopelang.Ast.scope_decl =
let scope_sig =
Scopelang.Ast.ScopeVarSet.fold
(fun var acc ->
let _, typ, _ = Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs in
let typ = (Ast.ScopeDefMap.find (Ast.ScopeDef.Var var) scope.scope_defs).scope_def_typ in
Scopelang.Ast.ScopeVarMap.add var typ acc)
scope.scope_vars Scopelang.Ast.ScopeVarMap.empty
in

View File

@ -881,8 +881,9 @@ let process_default (ctxt : Name_resolution.context) (scope : Scopelang.Ast.Scop
(def_key : Desugared.Ast.ScopeDef.t Pos.marked) (rule_id : Desugared.Ast.RuleName.t)
(param_uid : Scopelang.Ast.Var.t Pos.marked option)
(precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
(exception_to_rules : Pos.t Desugared.Ast.RuleMap.t) (just : Ast.expression Pos.marked option)
(cons : Ast.expression Pos.marked) : Desugared.Ast.rule =
(exception_to_rules : Desugared.Ast.RuleSet.t Pos.marked)
(just : Ast.expression Pos.marked option) (cons : Ast.expression Pos.marked) :
Desugared.Ast.rule =
let just = match just with Some just -> Some (translate_expr scope ctxt just) | None -> None in
let just = merge_conditions precond just (Pos.get_position def_key) in
let cons = translate_expr scope ctxt cons in
@ -940,22 +941,19 @@ let process_def (precond : Scopelang.Ast.expr Pos.marked Bindlib.box option)
let rule_name = def.definition_id in
let parent_rules =
match def.Ast.definition_exception_to with
| NotAnException -> Desugared.Ast.RuleMap.empty
| NotAnException -> (Desugared.Ast.RuleSet.empty, Pos.get_position def.Ast.definition_name)
| UnlabeledException -> (
match scope_def_ctxt.default_exception_rulename with
(* This should have been caught previously by check_unlabeled_exception *)
| None | Some (Name_resolution.Ambiguous _) -> assert false
| Some (Name_resolution.Unique name) ->
Desugared.Ast.RuleMap.singleton name (Pos.get_position def.Ast.definition_name))
(Desugared.Ast.RuleSet.singleton name, Pos.get_position def.Ast.definition_name))
| ExceptionToLabel label -> (
try
Desugared.Ast.RuleSet.fold
(fun parent_rule (acc : Pos.t Desugared.Ast.RuleMap.t) ->
Desugared.Ast.RuleMap.add parent_rule (Pos.get_position label) acc)
(Desugared.Ast.LabelMap.find
(Desugared.Ast.IdentMap.find (Pos.unmark label) scope_def_ctxt.label_idmap)
scope_def.scope_def_label_groups)
Desugared.Ast.RuleMap.empty
( Desugared.Ast.LabelMap.find
(Desugared.Ast.IdentMap.find (Pos.unmark label) scope_def_ctxt.label_idmap)
scope_def.scope_def_label_groups,
Pos.get_position def.Ast.definition_name )
with Not_found ->
Errors.raise_spanned_error
(Format.asprintf "Unknown label for the scope variable %a: \"%s\""