From 709b51beb6e8bd1ecbedc90f19a4e4814ab21193 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Fri, 24 May 2024 17:24:14 +0200 Subject: [PATCH] Fix hashing of submodule references --- compiler/catala_utils/uid.ml | 14 +++++++------ compiler/catala_utils/uid.mli | 9 ++++++--- compiler/desugared/ast.ml | 32 +++++++++++++----------------- compiler/desugared/ast.mli | 12 +++++------ compiler/desugared/from_surface.ml | 31 ++++++++++++++--------------- compiler/shared_ast/type.mli | 4 ++-- 6 files changed, 50 insertions(+), 52 deletions(-) diff --git a/compiler/catala_utils/uid.ml b/compiler/catala_utils/uid.ml index 508979bf..83c89e52 100644 --- a/compiler/catala_utils/uid.ml +++ b/compiler/catala_utils/uid.ml @@ -114,12 +114,14 @@ module Path = struct let equal = List.equal Module.equal let compare = List.compare Module.compare - let rec strip n p = - if n = 0 then p - else - match p with - | _ :: p -> strip (n - 1) p - | [] -> invalid_arg "Uid.Path.strip" + let strip prefix p0 = + let rec aux prefix p = + match prefix, p with + | pfx1 :: pfx, p1 :: p -> if Module.equal pfx1 p1 then aux pfx p else p0 + | [], p -> p + | _ -> p0 + in + aux prefix p0 end module QualifiedMarkedString = struct diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index afd76abf..a51e6cf4 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -90,6 +90,10 @@ module Path : sig val format : Format.formatter -> t -> unit val equal : t -> t -> bool val compare : t -> t -> int + + val strip : t -> t -> t + (** [strip pfx p] removed [pfx] from the start of [p]. if [p] doesn't start + with [pfx], it is returned unchanged *) end (** Same as [Gen] but also registers path information *) @@ -99,7 +103,6 @@ module Gen_qualified (_ : Style) () : sig val fresh : Path.t -> MarkedString.info -> t val path : t -> Path.t val get_info : t -> MarkedString.info - val hash : strip:int -> t -> Hash.t - (* [strip] strips that number of elements from the start of the path before - hashing *) + val hash : strip:Path.t -> t -> Hash.t + (* [strip] strips that prefix from the start of the path before hashing *) end diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 70f8df18..80bc4638 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -71,13 +71,17 @@ module ScopeDef = struct ScopeVar.format ppf (Mark.remove v); format_kind ppf k - let hash_kind = function - | Var None -> Hashtbl.hash `VarNone - | Var (Some st) -> Hashtbl.hash (`VarSome (StateName.id st)) - | SubScopeInput { var_within_origin_scope = v; _ } -> - Hashtbl.hash (`SubScopeInput (ScopeVar.id v)) + open Hash.Op - let hash (v, k) = Hashtbl.hash (ScopeVar.id (Mark.remove v), hash_kind k) + let hash_kind ~strip = function + | Var v -> !`Var % Hash.option StateName.hash v + | SubScopeInput { name; var_within_origin_scope } -> + !`SubScopeInput + % ScopeName.hash ~strip name + % ScopeVar.hash var_within_origin_scope + + let hash ~strip (v, k) = + Hash.Op.(ScopeVar.hash (Mark.remove v) % hash_kind ~strip k) end include Base @@ -288,24 +292,16 @@ module Hash = struct % !(d.scope_def_is_condition : bool) % io d.scope_def_io - let scope_def ~strip (var, kind) = - ScopeVar.hash (Mark.remove var) - % - match kind with - | ScopeDef.Var st -> Hash.option StateName.hash st - | ScopeDef.SubScopeInput { name; var_within_origin_scope } -> - ScopeName.hash ~strip name % ScopeVar.hash var_within_origin_scope - let scope ~strip s = Hash.map ScopeVar.Map.fold ScopeVar.hash var_or_state s.scope_vars % Hash.map ScopeVar.Map.fold ScopeVar.hash (ScopeName.hash ~strip) s.scope_sub_scopes % ScopeName.hash ~strip s.scope_uid - % Hash.map ScopeDef.Map.fold (scope_def ~strip) (scope_decl ~strip) + % Hash.map ScopeDef.Map.fold (ScopeDef.hash ~strip) (scope_decl ~strip) s.scope_defs (* assertions, options, etc. are not expected to be part of interfaces *) - let modul ?(strip = 0) m = + let modul ?(strip = []) m = Hash.map ScopeName.Map.fold (ScopeName.hash ~strip) (scope ~strip) (ScopeName.Map.filter (fun _ s -> s.scope_visibility = Public) @@ -316,8 +312,8 @@ module Hash = struct (fun _ td -> td.topdef_visibility = Public) m.module_topdefs) - let module_binding ?(root = false) modname m = - ModuleName.hash modname % modul ~strip:(if root then 0 else 1) m + let module_binding modname m = + ModuleName.hash modname % modul ~strip:[modname] m end let rec locations_used e : LocationSet.t = diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index d20212fe..9e23793b 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -32,7 +32,7 @@ module ScopeDef : sig val equal_kind : kind -> kind -> bool val compare_kind : kind -> kind -> int val format_kind : Format.formatter -> kind -> unit - val hash_kind : kind -> int + val hash_kind : strip:Uid.Path.t -> kind -> Hash.t type t = ScopeVar.t Mark.pos * kind @@ -40,7 +40,7 @@ module ScopeDef : sig val compare : t -> t -> int val get_position : t -> Pos.t val format : Format.formatter -> t -> unit - val hash : t -> int + val hash : strip:Uid.Path.t -> t -> Hash.t module Map : Map.S with type key = t module Set : Set.S with type elt = t @@ -154,11 +154,9 @@ module Hash : sig (** The [strip] argument below strips as many leading path components before hashing *) - val scope : strip:int -> scope -> Hash.t - val modul : ?strip:int -> modul -> Hash.t - - val module_binding : ?root:bool -> ModuleName.t -> modul -> Hash.t - (** This strips 1 path component by default unless [root] is [true] *) + val scope : strip:Uid.Path.t -> scope -> Hash.t + val modul : ?strip:Uid.Path.t -> modul -> Hash.t + val module_binding : ModuleName.t -> modul -> Hash.t end (** {1 Helpers} *) diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index ba369cf5..a6af6fee 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -1703,19 +1703,20 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : ModuleName.Map.mapi (fun mname mctx -> let m = - { - Ast.module_scopes = get_scopes mctx; - Ast.module_topdefs = - Ident.Map.fold - (fun _ name acc -> - let topdef_type, topdef_visibility = - TopdefName.Map.find name ctxt.Name_resolution.topdefs - in - TopdefName.Map.add name - { Ast.topdef_expr = None; topdef_visibility; topdef_type } - acc) - mctx.topdefs TopdefName.Map.empty; - } in + { + Ast.module_scopes = get_scopes mctx; + Ast.module_topdefs = + Ident.Map.fold + (fun _ name acc -> + let topdef_type, topdef_visibility = + TopdefName.Map.find name ctxt.Name_resolution.topdefs + in + TopdefName.Map.add name + { Ast.topdef_expr = None; topdef_visibility; topdef_type } + acc) + mctx.topdefs TopdefName.Map.empty; + } + in m, Ast.Hash.module_binding mname m) ctxt.modules in @@ -1816,7 +1817,5 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : (desugared.Ast.program_module_name |> Option.map @@ fun (mname, _) -> - ( mname, - Ast.Hash.module_binding ~root:true mname desugared.Ast.program_root ) - ); + mname, Ast.Hash.module_binding mname desugared.Ast.program_root); } diff --git a/compiler/shared_ast/type.mli b/compiler/shared_ast/type.mli index 023300f1..5d026da7 100644 --- a/compiler/shared_ast/type.mli +++ b/compiler/shared_ast/type.mli @@ -26,8 +26,8 @@ val equal : t -> t -> bool val equal_list : t list -> t list -> bool val compare : t -> t -> int -val hash : strip:int -> t -> Hash.t -(** The [strip] argument strips as many leading path components in included +val hash : strip:Uid.Path.t -> t -> Hash.t +(** The [strip] argument strips the given leading path components in included identifiers before hashing *) val unifiable : t -> t -> bool