From 403156b36e8dcceb536cff2d3a4df19d9d2c8368 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Fri, 24 May 2024 14:26:44 +0200 Subject: [PATCH] Computation and checking of module hashes This includes a few separate changes: - pass visibility information of declarations (depending on wether the declaration was in a ```catala-metadata block or not) - add reasonable hash computation functions to discriminate the interfaces. In particular: * Uids have a `hash` function that depends on their string, but not on their actual uid (which is not stable between runs of the compiler) ; the existing `hash` function and its uses have been renamed to `id`. * The `Hash` module provides the tools to properly combine hashes, etc. While we rely on `Hashtbl.hash` for the atoms, we take care not to use it on any recursive structure (it relies on a bounded traversal). - insert the hashes in the artefacts, and properly check and report those (for OCaml) **Remains to do**: - Record and check the hashes in the other backends - Provide a way to get stable inline-test outputs in the presence of module hashes - Provide a way to write external modules that don't break at every Catala update. --- compiler/catala_utils/hash.ml | 108 +++++++++++++++++++++++++ compiler/catala_utils/hash.mli | 73 +++++++++++++++++ compiler/catala_utils/mark.ml | 1 + compiler/catala_utils/mark.mli | 4 + compiler/catala_utils/string.ml | 1 + compiler/catala_utils/string.mli | 2 + compiler/catala_utils/uid.ml | 24 +++++- compiler/catala_utils/uid.mli | 16 +++- compiler/desugared/ast.ml | 80 ++++++++++++++++-- compiler/desugared/ast.mli | 27 ++++++- compiler/desugared/dependency.ml | 8 +- compiler/desugared/disambiguate.ml | 12 ++- compiler/desugared/from_surface.ml | 107 +++++++++++++++++------- compiler/desugared/name_resolution.ml | 98 +++++++++++++--------- compiler/desugared/name_resolution.mli | 8 +- compiler/driver.ml | 20 ++++- compiler/lcalc/to_ocaml.ml | 25 ++++-- compiler/lcalc/to_ocaml.mli | 2 + compiler/plugins/api_web.ml | 2 +- compiler/plugins/explain.ml | 5 +- compiler/plugins/lazy_interp.ml | 5 +- compiler/scalc/ast.ml | 2 +- compiler/scalc/from_lcalc.ml | 2 +- compiler/scalc/print.ml | 4 +- compiler/scalc/to_c.ml | 8 +- compiler/scalc/to_python.ml | 12 +-- compiler/scalc/to_r.ml | 8 +- compiler/scopelang/ast.ml | 2 +- compiler/scopelang/ast.mli | 7 +- compiler/scopelang/dependency.ml | 8 +- compiler/scopelang/from_desugared.ml | 8 +- compiler/shared_ast/definitions.ml | 7 +- compiler/shared_ast/interpreter.ml | 53 ++++++++---- compiler/shared_ast/interpreter.mli | 2 +- compiler/shared_ast/program.ml | 6 +- compiler/shared_ast/program.mli | 3 +- compiler/shared_ast/type.ml | 16 ++++ compiler/shared_ast/type.mli | 7 ++ compiler/shared_ast/typing.ml | 3 +- runtimes/ocaml/runtime.ml | 4 +- runtimes/ocaml/runtime.mli | 4 +- tests/modules/good/mod_def.catala_en | 8 ++ 42 files changed, 640 insertions(+), 162 deletions(-) create mode 100644 compiler/catala_utils/hash.ml create mode 100644 compiler/catala_utils/hash.mli diff --git a/compiler/catala_utils/hash.ml b/compiler/catala_utils/hash.ml new file mode 100644 index 00000000..7af724c8 --- /dev/null +++ b/compiler/catala_utils/hash.ml @@ -0,0 +1,108 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2024 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +type t = int + +let mix (h1 : t) (h2 : t) : t = Hashtbl.hash (h1, h2) +let raw = Hashtbl.hash + +module Op = struct + let ( % ) = mix + let ( ! ) = raw +end + +open Op + +let option f = function None -> !`None | Some x -> !`Some % f x +let list hf l = List.fold_left (fun acc x -> acc % hf x) !`ListEmpty l + +let map fold_fun kh vh map = + fold_fun (fun k v acc -> acc lxor (kh k % vh v)) map !`HashMapDelim + +module Flags : sig + type nonrec t = private t + + val pass : + (t -> 'a) -> + avoid_exceptions:bool -> + closure_conversion:bool -> + monomorphize_types:bool -> + 'a + + val of_t : int -> t +end = struct + type nonrec t = t + + let pass k ~avoid_exceptions ~closure_conversion ~monomorphize_types = + (* Should not affect the call convention or actual interfaces: include, + optimize, check_invariants, typed *) + !(avoid_exceptions : bool) + % !(closure_conversion : bool) + % !(monomorphize_types : bool) + % (* The following may not affect the call convention, but we want it set in + an homogeneous way *) + !(Global.options.trace : bool) + % !(Global.options.max_prec_digits : int) + |> k + + let of_t t = t +end + +type full = { catala_version : t; flags_hash : Flags.t; contents : t } + +let finalise t = + Flags.pass (fun flags_hash -> + { catala_version = !(Version.v : string); flags_hash; contents = t }) + +let to_string full = + Printf.sprintf "CM0|%08x|%08x|%08x" full.catala_version + (full.flags_hash :> int) + full.contents + +(* Putting color inside the hash makes them much easier to differentiate at a + glance *) +let format ppf full = + let open Ocolor_types in + let pcolor col f x = + Format.pp_open_stag ppf Ocolor_format.(Ocolor_style_tag (Fg (C24 col))); + f x; + Format.pp_close_stag ppf () + in + let tag = pcolor { r24 = 172; g24 = 172; b24 = 172 } in + let auto i = + { + r24 = 128 + (i mod 128); + g24 = 128 + ((i lsr 10) mod 128); + b24 = 128 + ((i lsr 20) mod 128); + } + in + let phash h = + let col = auto h in + pcolor col (Format.fprintf ppf "%08x") h + in + tag (Format.pp_print_string ppf) "CM0|"; + phash full.catala_version; + tag (Format.pp_print_string ppf) "|"; + phash (full.flags_hash :> int); + tag (Format.pp_print_string ppf) "|"; + phash full.contents + +let of_string s = + try + Scanf.sscanf s "CM0|%08x|%08x|%08x" + (fun catala_version flags_hash contents -> + { catala_version; flags_hash = Flags.of_t flags_hash; contents }) + with Scanf.Scan_failure _ -> failwith "Hash.of_string" diff --git a/compiler/catala_utils/hash.mli b/compiler/catala_utils/hash.mli new file mode 100644 index 00000000..89435c04 --- /dev/null +++ b/compiler/catala_utils/hash.mli @@ -0,0 +1,73 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2024 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +(** Hashes for the identification of modules. + + In contrast with OCaml's basic `Hashtbl.hash`, they process the full depth + of terms. Any meaningful interface change in a module should only be in hash + collision with a 1/2^30 probability. *) + +type t = private int +(** Native Hasthbl.hash hashes, value is truncated to 30 bits whatever the + architecture (positive 31-bit integers) *) + +type full +(** A "full" hash includes the Catala version and compilation flags, alongside + the module interface *) + +val raw : 'a -> t +(** [Hashtbl.hash]. Do not use on deep types (it has a bounded depth), use + specific hashing functions. *) + +module Op : sig + val ( ! ) : 'a -> t + (** Shortcut to [raw]. Same warning: use with an explicit type annotation + [!(foo: string)] to ensure it's not called on types that are recursive or + include annotations. + + Hint: we use [!`Foo] as a fancy way to generate constants for + discriminating constructions *) + + val ( % ) : t -> t -> t + (** Safe combination of two hashes (non commutative or associative, etc.) *) +end + +val option : ('a -> t) -> 'a option -> t +val list : ('a -> t) -> 'a list -> t + +val map : + (('k -> 'v -> t -> t) -> 'map -> t -> t) -> + ('k -> t) -> + ('v -> t) -> + 'map -> + t +(** [map fold_f key_hash_f value_hash_f map] computes the hash of a map. The + first argument is expected to be a [Foo.Map.fold] function. The result is + independent of the ordering of the map. *) + +val finalise : + t -> + avoid_exceptions:bool -> + closure_conversion:bool -> + monomorphize_types:bool -> + full +(** Turns a raw interface hash into a full hash, ready for printing *) + +val to_string : full -> string +val format : Format.formatter -> full -> unit + +val of_string : string -> full +(** @raise Failure *) diff --git a/compiler/catala_utils/mark.ml b/compiler/catala_utils/mark.ml index 8d50f263..ef1be20c 100644 --- a/compiler/catala_utils/mark.ml +++ b/compiler/catala_utils/mark.ml @@ -29,6 +29,7 @@ let fold f (x, _) = f x let fold2 f (x, _) (y, _) = f x y let compare cmp a b = fold2 cmp a b let equal eq a b = fold2 eq a b +let hash f (x, _) = f x class ['self] marked_map = object (_self : 'self) diff --git a/compiler/catala_utils/mark.mli b/compiler/catala_utils/mark.mli index 95a5ace5..ebbf44a7 100644 --- a/compiler/catala_utils/mark.mli +++ b/compiler/catala_utils/mark.mli @@ -41,6 +41,10 @@ val compare : ('a -> 'a -> int) -> ('a, 'm) ed -> ('a, 'm) ed -> int val equal : ('a -> 'a -> bool) -> ('a, 'm) ed -> ('a, 'm) ed -> bool (** Tests equality of two marked values {b ignoring marks} *) +val hash : ('a -> Hash.t) -> ('a, 'm) ed -> Hash.t +(** Computes the hash of the marked values using the given function + {b ignoring mark} *) + (** Visitors *) class ['self] marked_map : object ('self) diff --git a/compiler/catala_utils/string.ml b/compiler/catala_utils/string.ml index 44dc5e6f..0af8ec76 100644 --- a/compiler/catala_utils/string.ml +++ b/compiler/catala_utils/string.ml @@ -101,6 +101,7 @@ module Arg = struct end let compare = Arg.compare +let hash t = Hash.raw t module Set = Set.Make (Arg) module Map = Map.Make (Arg) diff --git a/compiler/catala_utils/string.mli b/compiler/catala_utils/string.mli index ee7248be..b16b6723 100644 --- a/compiler/catala_utils/string.mli +++ b/compiler/catala_utils/string.mli @@ -23,6 +23,8 @@ module Map : Map.S with type key = string val compare : string -> string -> int (** String comparison with natural ordering of numbers within strings *) +val hash : string -> Hash.t + val to_ascii : string -> string (** Removes all non-ASCII diacritics from a string by converting them to their base letter in the Latin alphabet. *) diff --git a/compiler/catala_utils/uid.ml b/compiler/catala_utils/uid.ml index 23beedae..508979bf 100644 --- a/compiler/catala_utils/uid.ml +++ b/compiler/catala_utils/uid.ml @@ -21,6 +21,7 @@ module type Info = sig val format : Format.formatter -> info -> unit val equal : info -> info -> bool val compare : info -> info -> int + val hash : info -> Hash.t end module type Id = sig @@ -33,7 +34,8 @@ module type Id = sig val equal : t -> t -> bool val format : Format.formatter -> t -> unit val to_string : t -> string - val hash : t -> int + val id : t -> int + val hash : t -> Hash.t module Set : Set.S with type elt = t module Map : Map.S with type key = t @@ -68,8 +70,9 @@ module Make (X : Info) (S : Style) () : Id with type info = X.info = struct { id = !counter; info } let get_info (uid : t) : X.info = uid.info - let hash (x : t) : int = x.id + let id (x : t) : int = x.id let to_string t = X.to_string t.info + let hash t = X.hash t.info module Set = Set.Make (Ordering) module Map = Map.Make (Ordering) @@ -84,6 +87,7 @@ module MarkedString = struct let format fmt i = String.format fmt (to_string i) let equal = Mark.equal String.equal let compare = Mark.compare String.compare + let hash = Mark.hash String.hash end module Gen (S : Style) () = Make (MarkedString) (S) () @@ -109,6 +113,13 @@ module Path = struct let to_string p = String.concat "." (List.map Module.to_string p) let equal = List.equal Module.equal let compare = List.compare Module.compare + + let rec strip n p = + if n = 0 then p + else + match p with + | _ :: p -> strip (n - 1) p + | [] -> invalid_arg "Uid.Path.strip" end module QualifiedMarkedString = struct @@ -125,12 +136,21 @@ module QualifiedMarkedString = struct let compare (p1, i1) (p2, i2) = match Path.compare p1 p2 with 0 -> MarkedString.compare i1 i2 | n -> n + + let hash (p, i) = + let open Hash.Op in + Hash.list Module.hash p % MarkedString.hash i end module Gen_qualified (S : Style) () = struct include Make (QualifiedMarkedString) (S) () let fresh path t = fresh (path, t) + + let hash ~strip t = + let p, i = get_info t in + QualifiedMarkedString.hash (Path.strip strip p, i) + let path t = fst (get_info t) let get_info t = snd (get_info t) end diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index 4ed2e8d7..afd76abf 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -28,6 +28,9 @@ module type Info = sig val compare : info -> info -> int (** Comparison disregards position *) + + val hash : info -> Hash.t + (** Hashing disregards position *) end module MarkedString : Info with type info = string Mark.pos @@ -48,7 +51,15 @@ module type Id = sig val equal : t -> t -> bool val format : Format.formatter -> t -> unit val to_string : t -> string - val hash : t -> int + + val id : t -> int + (** Returns the unique ID of the identifier *) + + val hash : t -> Hash.t + (** While [id] returns a unique ID valable for a given Uid instance within a + given run of catala, this is a raw hash of the identifier string. + Therefore, it may collide within a given program, but remains meaninful + across separate compilations. *) module Set : Set.S with type elt = t module Map : Map.S with type key = t @@ -88,4 +99,7 @@ module Gen_qualified (_ : Style) () : sig val fresh : Path.t -> MarkedString.info -> t val path : t -> Path.t val get_info : t -> MarkedString.info + val hash : strip:int -> t -> Hash.t + (* [strip] strips that number of elements from the start of the path before + hashing *) end diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 9ac16a16..70f8df18 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -72,11 +72,12 @@ module ScopeDef = struct format_kind ppf k let hash_kind = function - | Var None -> 0 - | Var (Some st) -> StateName.hash st - | SubScopeInput { var_within_origin_scope = v; _ } -> ScopeVar.hash v + | Var None -> Hashtbl.hash `VarNone + | Var (Some st) -> Hashtbl.hash (`VarSome (StateName.id st)) + | SubScopeInput { var_within_origin_scope = v; _ } -> + Hashtbl.hash (`SubScopeInput (ScopeVar.id v)) - let hash (v, k) = Int.logxor (ScopeVar.hash (Mark.remove v)) (hash_kind k) + let hash (v, k) = Hashtbl.hash (ScopeVar.id (Mark.remove v), hash_kind k) end include Base @@ -231,6 +232,8 @@ type scope_def = { type var_or_states = WholeVar | States of StateName.t list +(* If fields are added, make sure to consider including them in the hash + computations below *) type scope = { scope_vars : var_or_states ScopeVar.Map.t; scope_sub_scopes : ScopeName.t ScopeVar.Map.t; @@ -239,21 +242,84 @@ type scope = { scope_assertions : assertion AssertionName.Map.t; scope_options : catala_option Mark.pos list; scope_meta_assertions : meta_assertion list; + scope_visibility : visibility; +} + +type topdef = { + topdef_expr : expr option; + topdef_type : typ; + topdef_visibility : visibility; } type modul = { module_scopes : scope ScopeName.Map.t; - module_topdefs : (expr option * typ) TopdefName.Map.t; + module_topdefs : topdef TopdefName.Map.t; } type program = { - program_module_name : Ident.t Mark.pos option; + program_module_name : (ModuleName.t * Hash.t) option; program_ctx : decl_ctx; program_modules : modul ModuleName.Map.t; program_root : modul; program_lang : Global.backend_lang; } +module Hash = struct + open Hash.Op + + let var_or_state = function + | WholeVar -> !`WholeVar + | States s -> !`States % Hash.list StateName.hash s + + let io x = + !(Mark.remove x.io_input : Runtime.io_input) + % !(Mark.remove x.io_output : bool) + + let scope_decl ~strip d = + (* scope_def_rules is ignored (not part of the interface) *) + Type.hash ~strip d.scope_def_typ + % Hash.option + (fun (lst, _) -> + List.fold_left + (fun acc (name, ty) -> + acc % Uid.MarkedString.hash name % Type.hash ~strip ty) + !`SDparams lst) + d.scope_def_parameters + % !(d.scope_def_is_condition : bool) + % io d.scope_def_io + + let scope_def ~strip (var, kind) = + ScopeVar.hash (Mark.remove var) + % + match kind with + | ScopeDef.Var st -> Hash.option StateName.hash st + | ScopeDef.SubScopeInput { name; var_within_origin_scope } -> + ScopeName.hash ~strip name % ScopeVar.hash var_within_origin_scope + + let scope ~strip s = + Hash.map ScopeVar.Map.fold ScopeVar.hash var_or_state s.scope_vars + % Hash.map ScopeVar.Map.fold ScopeVar.hash (ScopeName.hash ~strip) + s.scope_sub_scopes + % ScopeName.hash ~strip s.scope_uid + % Hash.map ScopeDef.Map.fold (scope_def ~strip) (scope_decl ~strip) + s.scope_defs + (* assertions, options, etc. are not expected to be part of interfaces *) + + let modul ?(strip = 0) m = + Hash.map ScopeName.Map.fold (ScopeName.hash ~strip) (scope ~strip) + (ScopeName.Map.filter + (fun _ s -> s.scope_visibility = Public) + m.module_scopes) + % Hash.map TopdefName.Map.fold (TopdefName.hash ~strip) + (fun td -> Type.hash ~strip td.topdef_type) + (TopdefName.Map.filter + (fun _ td -> td.topdef_visibility = Public) + m.module_topdefs) + + let module_binding ?(root = false) modname m = + ModuleName.hash modname % modul ~strip:(if root then 0 else 1) m +end + let rec locations_used e : LocationSet.t = match e with | ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m) @@ -311,5 +377,5 @@ let fold_exprs ~(f : 'a -> expr -> 'a) ~(init : 'a) (p : program) : 'a = p.program_root.module_scopes init in TopdefName.Map.fold - (fun _ (e, _) acc -> Option.fold ~none:acc ~some:(f acc) e) + (fun _ tdef acc -> Option.fold ~none:acc ~some:(f acc) tdef.topdef_expr) p.program_root.module_topdefs acc diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index b8f787bf..d20212fe 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -123,16 +123,23 @@ type scope = { (** empty outside of the root module *) scope_options : catala_option Mark.pos list; scope_meta_assertions : meta_assertion list; + scope_visibility : visibility; +} + +type topdef = { + topdef_expr : expr option; (** Always [None] outside of the root module *) + topdef_type : typ; + topdef_visibility : visibility; + (** Necessarily [Public] outside of the root module *) } type modul = { module_scopes : scope ScopeName.Map.t; - module_topdefs : (expr option * typ) TopdefName.Map.t; - (** the expr is [None] outside of the root module *) + module_topdefs : topdef TopdefName.Map.t; } type program = { - program_module_name : Ident.t Mark.pos option; + program_module_name : (ModuleName.t * Hash.t) option; program_ctx : decl_ctx; program_modules : modul ModuleName.Map.t; (** Contains all submodules of the program, in a flattened structure *) @@ -140,6 +147,20 @@ type program = { program_lang : Global.backend_lang; } +(** {1 Interface hash computations} *) + +(** These hashes are computed on interfaces: only signatures are considered. *) +module Hash : sig + (** The [strip] argument below strips as many leading path components before + hashing *) + + val scope : strip:int -> scope -> Hash.t + val modul : ?strip:int -> modul -> Hash.t + + val module_binding : ?root:bool -> ModuleName.t -> modul -> Hash.t + (** This strips 1 path component by default unless [root] is [true] *) +end + (** {1 Helpers} *) val locations_used : expr -> LocationSet.t diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 51490cdf..3006560c 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -39,9 +39,9 @@ module Vertex = struct let hash x = match x with - | Var (x, None) -> ScopeVar.hash x - | Var (x, Some sx) -> Int.logxor (ScopeVar.hash x) (StateName.hash sx) - | Assertion a -> Ast.AssertionName.hash a + | Var (x, None) -> ScopeVar.id x + | Var (x, Some sx) -> Hashtbl.hash (ScopeVar.id x, StateName.id sx) + | Assertion a -> Hashtbl.hash (`Assert (Ast.AssertionName.id a)) let compare x y = match x, y with @@ -252,7 +252,7 @@ module ExceptionVertex = struct let hash (x : t) : int = RuleName.Map.fold - (fun r _ acc -> Int.logxor (RuleName.hash r) acc) + (fun r _ acc -> Hashtbl.hash (RuleName.id r, acc)) x.rules 0 let equal x y = compare x y = 0 diff --git a/compiler/desugared/disambiguate.ml b/compiler/desugared/disambiguate.ml index 024645f6..2de18f72 100644 --- a/compiler/desugared/disambiguate.ml +++ b/compiler/desugared/disambiguate.ml @@ -98,10 +98,14 @@ let program prg = in let module_topdefs = TopdefName.Map.map - (function - | Some e, ty -> - Some (Expr.unbox (expr prg.program_ctx env (Expr.box e))), ty - | None, ty -> None, ty) + (fun def -> + { + def with + topdef_expr = + Option.map + (fun e -> Expr.unbox (expr prg.program_ctx env (Expr.box e))) + def.topdef_expr; + }) prg.program_root.module_topdefs in let module_scopes = diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index d9af0fc5..12fa0ab3 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -313,7 +313,7 @@ let rec translate_expr in let e2 = rec_helper ~local_vars e2 in Expr.make_abs [| binding_var |] e2 [tau] pos_op) - (EnumName.Map.find enum_uid ctxt.enums) + (fst (EnumName.Map.find enum_uid ctxt.enums)) in Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark | Binop ((((S.And | S.Or | S.Xor), _) as op), e1, e2) -> @@ -613,7 +613,7 @@ let rec translate_expr StructField.Map.add f_uid f_e s_fields) StructField.Map.empty fields in - let expected_s_fields = StructName.Map.find s_uid ctxt.structs in + let expected_s_fields, _ = StructName.Map.find s_uid ctxt.structs in if StructField.Map.exists (fun expected_f _ -> not (StructField.Map.mem expected_f s_fields)) @@ -719,7 +719,7 @@ let rec translate_expr Expr.make_abs [| nop_var |] (Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark) [tau] pos) - (EnumName.Map.find enum_uid ctxt.enums) + (fst (EnumName.Map.find enum_uid ctxt.enums)) in Expr.ematch ~e:(rec_helper e1) ~name:enum_uid ~cases emark | ArrayLit es -> Expr.earray (List.map rec_helper es) emark @@ -970,7 +970,7 @@ and disambiguate_match_and_build_expression Expr.eabs e_binder [ EnumConstructor.Map.find c_uid - (EnumName.Map.find e_uid ctxt.Name_resolution.enums); + (fst (EnumName.Map.find e_uid ctxt.Name_resolution.enums)); ] (Mark.get case_body) in @@ -1037,6 +1037,7 @@ and disambiguate_match_and_build_expression if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err (); let missing_constructors = EnumName.Map.find e_uid ctxt.Name_resolution.enums + |> fst |> EnumConstructor.Map.filter_map (fun c_uid _ -> match EnumConstructor.Map.find_opt c_uid cases_d with | Some _ -> None @@ -1451,6 +1452,7 @@ let process_scope_use let process_topdef (ctxt : Name_resolution.context) (prgm : Ast.program) + (is_public : bool) (def : S.top_def) : Ast.program = let id = Ident.Map.find @@ -1493,12 +1495,14 @@ let process_topdef in Some (Expr.unbox_closed e) in + let topdef_visibility = if is_public then Public else Private in let module_topdefs = TopdefName.Map.update id (fun def0 -> match def0, expr_opt with - | None, eopt -> Some (eopt, typ) - | Some (eopt0, ty0), eopt -> ( + | None, eopt -> + Some { Ast.topdef_expr = eopt; topdef_visibility; topdef_type = typ } + | Some def0, eopt -> ( let err msg = Message.error ~extra_pos: @@ -1508,13 +1512,16 @@ let process_topdef ] (msg ^^ " for %a") TopdefName.format id in - if not (Type.equal ty0 typ) then err "Conflicting type definitions" + if not (Type.equal def0.Ast.topdef_type typ) then + err "Conflicting type definitions" else - match eopt0, eopt with + match def0.Ast.topdef_expr, eopt with | None, None -> err "Multiple declarations" | Some _, Some _ -> err "Multiple definitions" - | Some e, None -> Some (Some e, typ) - | None, Some e -> Some (Some e, ty0))) + | (Some _ as topdef_expr), None -> + Some { Ast.topdef_expr; topdef_visibility; topdef_type = typ } + | None, (Some _ as topdef_expr) -> + Some { def0 with Ast.topdef_expr })) prgm.Ast.program_root.module_topdefs in { prgm with program_root = { prgm.program_root with module_topdefs } } @@ -1680,6 +1687,7 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : scope_meta_assertions = []; scope_options = []; scope_uid = s_uid; + scope_visibility = s_context.Name_resolution.scope_visibility; } in let get_scopes mctx = @@ -1699,15 +1707,22 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : Ast.module_topdefs = Ident.Map.fold (fun _ name acc -> + let topdef_type, topdef_visibility = + TopdefName.Map.find name ctxt.Name_resolution.topdefs + in TopdefName.Map.add name - ( None, - TopdefName.Map.find name ctxt.Name_resolution.topdef_types - ) + { Ast.topdef_expr = None; topdef_visibility; topdef_type } acc) mctx.topdefs TopdefName.Map.empty; }) ctxt.modules in + let program_root = + { + Ast.module_scopes = get_scopes ctxt.Name_resolution.local; + Ast.module_topdefs = TopdefName.Map.empty; + } + in let program_ctx = let open Name_resolution in let ctx_scopes mctx acc = @@ -1720,23 +1735,36 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : in let ctx_modules = let rec aux mctx = - Ident.Map.fold - (fun _ m (M acc) -> - let sub = aux (ModuleName.Map.find m ctxt.modules) in - M (ModuleName.Map.add m sub acc)) - mctx.used_modules (M ModuleName.Map.empty) + let subs = + Ident.Map.fold + (fun _ m acc -> + let mctx = ModuleName.Map.find m ctxt.Name_resolution.modules in + let sub = aux mctx in + let mhash = + let intf = ModuleName.Map.find m program_modules in + Ast.Hash.module_binding m intf + (* We could include the hashes of submodule interfaces in the + hash of the module ; however, the module is already + responsible for checking the consistency of its dependencies + upon load, and that would result in harder to track errors on + mismatch. *) + in + ModuleName.Map.add m (mhash, sub) acc) + mctx.used_modules ModuleName.Map.empty + in + M subs in aux ctxt.local in { - ctx_structs = ctxt.structs; - ctx_enums = ctxt.enums; + ctx_structs = StructName.Map.map fst ctxt.structs; + ctx_enums = EnumName.Map.map fst ctxt.enums; ctx_scopes = ModuleName.Map.fold (fun _ -> ctx_scopes) ctxt.modules (ctx_scopes ctxt.local ScopeName.Map.empty); - ctx_topdefs = ctxt.topdef_types; + ctx_topdefs = TopdefName.Map.map fst ctxt.topdefs; ctx_struct_fields = ctxt.local.field_idmap; ctx_enum_constrs = ctxt.local.constructor_idmap; ctx_scope_index = @@ -1748,25 +1776,29 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : ctx_modules; } in + let program_module_name = + surface.Surface.Ast.program_module_name + |> Option.map + @@ fun id -> + let mname = ModuleName.fresh id in + let hash_placeholder = Hash.raw 0 in + mname, hash_placeholder + in let desugared = { Ast.program_lang = surface.program_lang; - Ast.program_module_name = surface.Surface.Ast.program_module_name; + Ast.program_module_name; Ast.program_modules; Ast.program_ctx; - Ast.program_root = - { - Ast.module_scopes = get_scopes ctxt.Name_resolution.local; - Ast.module_topdefs = TopdefName.Map.empty; - }; + Ast.program_root; } in - let process_code_block ctxt prgm block = + let process_code_block ctxt prgm is_meta block = List.fold_left (fun prgm item -> match Mark.remove item with | S.ScopeUse use -> process_scope_use ctxt prgm use - | S.Topdef def -> process_topdef ctxt prgm def + | S.Topdef def -> process_topdef ctxt prgm is_meta def | S.ScopeDecl _ | S.StructDecl _ | S.EnumDecl _ -> prgm) prgm block in @@ -1777,7 +1809,20 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : List.fold_left (fun prgm child -> process_structure prgm child) prgm children - | S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block + | S.CodeBlock (block, _, is_meta) -> + process_code_block ctxt prgm is_meta block | S.ModuleDef _ | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm in - List.fold_left process_structure desugared surface.S.program_items + let desugared = + List.fold_left process_structure desugared surface.S.program_items + in + { + desugared with + Ast.program_module_name = + (desugared.Ast.program_module_name + |> Option.map + @@ fun (mname, _) -> + ( mname, + Ast.Hash.module_binding ~root:true mname desugared.Ast.program_root ) + ); + } diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index 4c758094..f778f4e6 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -39,6 +39,7 @@ type scope_context = { scope_out_struct : StructName.t; sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) + scope_visibility : visibility; } (** Inside a scope, we distinguish between the variables and the subscopes. *) @@ -82,10 +83,11 @@ type module_context = { type context = { scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) - topdef_types : typ TopdefName.Map.t; - structs : struct_context StructName.Map.t; + topdefs : (typ * visibility) TopdefName.Map.t; + structs : (struct_context * visibility) StructName.Map.t; (** For each struct, its context *) - enums : enum_context EnumName.Map.t; (** For each enum, its context *) + enums : (enum_context * visibility) EnumName.Map.t; + (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) modules : module_context ModuleName.Map.t; @@ -426,8 +428,10 @@ let process_data_decl } (** Process a struct declaration *) -let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : - context = +let process_struct_decl + ?(visibility = Public) + (ctxt : context) + (sdecl : Surface.Ast.struct_decl) : context = let s_uid = get_struct ctxt sdecl.struct_decl_name in if sdecl.struct_decl_fields = [] then Message.error @@ -454,25 +458,28 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : let ctxt = { ctxt with local } in let structs = StructName.Map.update s_uid - (fun fields -> - match fields with + (function | None -> Some - (StructField.Map.singleton f_uid - (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) - | Some fields -> + ( StructField.Map.singleton f_uid + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ), + visibility ) + | Some (fields, _) -> Some - (StructField.Map.add f_uid - (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) - fields)) + ( StructField.Map.add f_uid + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) + fields, + visibility )) ctxt.structs in { ctxt with structs }) ctxt sdecl.struct_decl_fields (** Process an enum declaration *) -let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context - = +let process_enum_decl + ?(visibility = Public) + (ctxt : context) + (edecl : Surface.Ast.enum_decl) : context = let e_uid = get_enum ctxt edecl.enum_decl_name in if List.length edecl.enum_decl_cases = 0 then Message.error @@ -506,23 +513,24 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context | Some typ -> process_type ctxt typ in match cases with - | None -> Some (EnumConstructor.Map.singleton c_uid typ) - | Some fields -> Some (EnumConstructor.Map.add c_uid typ fields)) + | None -> Some (EnumConstructor.Map.singleton c_uid typ, visibility) + | Some (fields, _) -> + Some (EnumConstructor.Map.add c_uid typ fields, visibility)) ctxt.enums in { ctxt with enums }) ctxt edecl.enum_decl_cases -let process_topdef ctxt def = +let process_topdef ?(visibility = Public) ctxt def = let uid = Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.local.topdefs in { ctxt with - topdef_types = + topdefs = TopdefName.Map.add uid - (process_type ctxt def.Surface.Ast.topdef_type) - ctxt.topdef_types; + (process_type ctxt def.Surface.Ast.topdef_type, visibility) + ctxt.topdefs; } (** Process an item declaration *) @@ -536,8 +544,10 @@ let process_item_decl process_subscope_decl scope ctxt sub_decl (** Process a scope declaration *) -let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : - context = +let process_scope_decl + ?(visibility = Public) + (ctxt : context) + (decl : Surface.Ast.scope_decl) : context = let scope_uid = get_scope ctxt decl.scope_decl_name in let ctxt = List.fold_left @@ -588,11 +598,12 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : structs = StructName.Map.add (get_struct ctxt decl.scope_decl_name) - StructField.Map.empty ctxt.structs; + (StructField.Map.empty, visibility) + ctxt.structs; } else let ctxt = - process_struct_decl ctxt + process_struct_decl ~visibility ctxt { struct_decl_name = decl.scope_decl_name; struct_decl_fields = output_fields; @@ -634,8 +645,10 @@ let typedef_info = function | TScope (s, _) -> ScopeName.get_info s (** Process the names of all declaration items *) -let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : - context = +let process_name_item + ?(visibility = Public) + (ctxt : context) + (item : Surface.Ast.code_item Mark.pos) : context = let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg = Message.error ~fmt_pos: @@ -676,6 +689,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : scope_in_struct = in_struct_name; scope_out_struct = out_struct_name; sub_scopes = ScopeName.Set.empty; + scope_visibility = visibility; } ctxt.scopes in @@ -720,14 +734,16 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : { ctxt with local = { ctxt.local with topdefs } } (** Process a code item that is a declaration *) -let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : - context = +let process_decl_item + ?visibility + (ctxt : context) + (item : Surface.Ast.code_item Mark.pos) : context = match Mark.remove item with - | ScopeDecl decl -> process_scope_decl ctxt decl - | StructDecl sdecl -> process_struct_decl ctxt sdecl - | EnumDecl edecl -> process_enum_decl ctxt edecl + | ScopeDecl decl -> process_scope_decl ?visibility ctxt decl + | StructDecl sdecl -> process_struct_decl ?visibility ctxt sdecl + | EnumDecl edecl -> process_enum_decl ?visibility ctxt edecl | ScopeUse _ -> ctxt - | Topdef def -> process_topdef ctxt def + | Topdef def -> process_topdef ?visibility ctxt def (** Process a code block *) let process_code_block @@ -738,7 +754,11 @@ let process_code_block (** Process a law structure, only considering the code blocks *) let rec process_law_structure - (process_item : context -> Surface.Ast.code_item Mark.pos -> context) + (process_item : + ?visibility:visibility -> + context -> + Surface.Ast.code_item Mark.pos -> + context) (ctxt : context) (s : Surface.Ast.law_structure) : context = match s with @@ -746,8 +766,10 @@ let rec process_law_structure List.fold_left (fun ctxt child -> process_law_structure process_item ctxt child) ctxt children - | Surface.Ast.CodeBlock (block, _, _) -> - process_code_block process_item ctxt block + | Surface.Ast.CodeBlock (block, _, is_meta) -> + process_code_block + (process_item ~visibility:(if is_meta then Public else Private)) + ctxt block | Surface.Ast.LawInclude _ | Surface.Ast.LawText _ -> ctxt | Surface.Ast.ModuleDef _ | Surface.Ast.ModuleUse _ -> ctxt @@ -962,7 +984,7 @@ let empty_module_ctxt = let empty_ctxt = { scopes = ScopeName.Map.empty; - topdef_types = TopdefName.Map.empty; + topdefs = TopdefName.Map.empty; var_typs = ScopeVar.Map.empty; structs = StructName.Map.empty; enums = EnumName.Map.empty; @@ -1017,7 +1039,7 @@ let form_context (surface, mod_uses) surface_modules : context = in let ctxt = List.fold_left - (process_law_structure process_use_item) + (process_law_structure (fun ?visibility:_ -> process_use_item)) ctxt surface.Surface.Ast.program_items in (* Gather struct fields and enum constrs from direct modules: this helps with diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index ce54061d..918c9133 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -39,6 +39,7 @@ type scope_context = { scope_out_struct : StructName.t; sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) + scope_visibility : visibility; } (** Inside a scope, we distinguish between the variables and the subscopes. *) @@ -87,11 +88,12 @@ type module_context = { type context = { scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) - topdef_types : typ TopdefName.Map.t; + topdefs : (typ * visibility) TopdefName.Map.t; (** Types associated with the global definitions *) - structs : struct_context StructName.Map.t; + structs : (struct_context * visibility) StructName.Map.t; (** For each struct, its context *) - enums : enum_context EnumName.Map.t; (** For each enum, its context *) + enums : (enum_context * visibility) EnumName.Map.t; + (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) modules : module_context ModuleName.Map.t; diff --git a/compiler/driver.ml b/compiler/driver.ml index adfc07ee..335d46e7 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -712,7 +712,12 @@ module Commands = struct let prg, _ = Passes.dcalc options ~includes ~optimize ~check_invariants ~typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules + ~hashf: + Hash.( + finalise ~avoid_exceptions:false ~closure_conversion:false + ~monomorphize_types:false) + prg; print_interpretation_results options Interpreter.interpret_program_dcalc prg (get_scopeopt_uid prg.decl_ctx ex_scope_opt) @@ -781,7 +786,10 @@ module Commands = struct Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~monomorphize_types ~typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules + ~hashf: + (Hash.finalise ~avoid_exceptions ~closure_conversion ~monomorphize_types) + prg; print_interpretation_results options Interpreter.interpret_program_lcalc prg (get_scopeopt_uid prg.decl_ctx ex_scope_opt) @@ -844,7 +852,11 @@ module Commands = struct Message.debug "Writing to %s..." (Option.value ~default:"stdout" output_file); let exec_scope = Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt in - Lcalc.To_ocaml.format_program fmt prg ?exec_scope type_ordering + let hashf = + Hash.finalise ~avoid_exceptions ~closure_conversion:false + ~monomorphize_types:false + in + Lcalc.To_ocaml.format_program fmt prg ?exec_scope ~hashf type_ordering let ocaml_cmd = Cmd.v @@ -1038,7 +1050,7 @@ module Commands = struct in Format.open_hbox (); Format.pp_print_list ~pp_sep:Format.pp_print_space - (fun ppf m -> + (fun ppf (m, _h) -> let f = Pos.get_file (Mark.get (ModuleName.get_info m)) in let f = match prefix with diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 9b7a8e8b..1e3431a8 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -716,9 +716,16 @@ let commands = if commands = [] then entry_scopes else commands name format_var var name) scopes_with_no_input -let reexport_used_modules fmt modules = +let check_and_reexport_used_modules fmt ~hashf modules = List.iter - (fun m -> + (fun (m, hash) -> + Format.fprintf fmt + "@[let () =@ @[match Runtime_ocaml.Runtime.check_module \ + %S \"%a\"@ with@]@,\ + | Ok () -> ()@,\ + @[| Error h -> failwith \"Hash mismatch for module %a, it may \ + need recompiling\"@]@]@," + (ModuleName.to_string m) Hash.format (hashf hash) ModuleName.format m; Format.fprintf fmt "@[module %a@ = %a@]@," ModuleName.format m ModuleName.format m) modules @@ -726,7 +733,8 @@ let reexport_used_modules fmt modules = let format_module_registration fmt (bnd : ('m Ast.expr Var.t * _) String.Map.t) - modname = + modname + hash = Format.pp_open_vbox fmt 2; Format.pp_print_string fmt "let () ="; Format.pp_print_space fmt (); @@ -743,11 +751,13 @@ let format_module_registration (fun fmt (id, (var, _)) -> Format.fprintf fmt "@[%S,@ Obj.repr %a@]" id format_var var) fmt (String.Map.to_seq bnd); + (* TODO: pass the visibility info down from desugared, and filter what is + exported here *) Format.pp_close_box fmt (); Format.pp_print_char fmt ' '; Format.pp_print_string fmt "]"; Format.pp_print_space fmt (); - Format.pp_print_string fmt "\"todo-module-hash\""; + Format.fprintf fmt "\"%a\"" Hash.format hash; Format.pp_close_box fmt (); Format.pp_close_box fmt (); Format.pp_print_newline fmt () @@ -766,17 +776,20 @@ let format_program (fmt : Format.formatter) ?exec_scope ?(exec_args = true) + ~(hashf : Hash.t -> Hash.full) (p : 'm Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = Format.pp_open_vbox fmt 0; Format.pp_print_string fmt header; - reexport_used_modules fmt (Program.modules_to_list p.decl_ctx.ctx_modules); + check_and_reexport_used_modules fmt ~hashf + (Program.modules_to_list p.decl_ctx.ctx_modules); format_ctx type_ordering fmt p.decl_ctx; let bnd = format_code_items p.decl_ctx fmt p.code_items in Format.pp_print_cut fmt (); let () = match p.module_name, exec_scope with - | Some modname, None -> format_module_registration fmt bnd modname + | Some (modname, hash), None -> + format_module_registration fmt bnd modname (hashf hash) | None, Some scope_name -> let scope_body = Program.get_scope_body p scope_name in format_scope_exec p.decl_ctx fmt bnd scope_name scope_body diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index b6f6a9f5..85f49490 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -14,6 +14,7 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils open Shared_ast (** Formats a lambda calculus program into a valid OCaml program *) @@ -40,6 +41,7 @@ val format_program : Format.formatter -> ?exec_scope:ScopeName.t -> ?exec_args:bool -> + hashf:(Hash.t -> Hash.full) -> 'm Ast.program -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index 63cf61fc..73d1a22a 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -489,7 +489,7 @@ let run (Option.value ~default:"stdout" jsoo_output_file); let modname = match prg.module_name with - | Some m -> ModuleName.to_string m + | Some (m, _) -> ModuleName.to_string m | None -> String.capitalize_ascii Filename.( diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index f059eb18..929f15f9 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -1381,7 +1381,10 @@ let run includes optimize ex_scope explain_options global_options = Driver.Passes.dcalc global_options ~includes ~optimize ~check_invariants:false ~typed:Expr.typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules prg + ~hashf: + (Hash.finalise ~avoid_exceptions:false ~closure_conversion:false + ~monomorphize_types:false); let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in (* let result_expr, env = interpret_program prg scope in *) let g, base_vars, env = program_to_graph explain_options prg scope in diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index 8ae7826e..7a2d4d02 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -271,7 +271,10 @@ let run includes optimize check_invariants ex_scope options = Driver.Passes.dcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules prg + ~hashf: + (Hash.finalise ~avoid_exceptions:false ~closure_conversion:false + ~monomorphize_types:false); let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in let result_expr, _env = interpret_program prg scope in let fmt = Format.std_formatter in diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index 7beb09fe..cf27db91 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -121,5 +121,5 @@ type ctx = { decl_ctx : decl_ctx; modules : VarName.t ModuleName.Map.t } type program = { ctx : ctx; code_items : code_item list; - module_name : ModuleName.t option; + module_name : (ModuleName.t * Hash.t) option; } diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index c80024b7..507802b9 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -659,7 +659,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : A.program = let modules = List.fold_left - (fun acc m -> + (fun acc (m, _hash) -> let vname = Mark.map (( ^ ) "Module_") (ModuleName.get_info m) in (* The "Module_" prefix is a workaround name clashes for same-name structs and modules, Python in particular mixes everything in one diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 541cdbde..9cbf411b 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -21,10 +21,10 @@ open Ast let needs_parens (_e : expr) : bool = false let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit = - Format.fprintf fmt "%a_%d" VarName.format v (VarName.hash v) + Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v) let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = - Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.hash v) + Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.id v) let rec format_expr (decl_ctx : decl_ctx) diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index de698df5..b46541da 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -96,11 +96,11 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let v_str = Mark.remove (VarName.get_info v) in - let hash = VarName.hash v in + let id = VarName.id v in let local_id = match StringMap.find_opt v_str !string_counter_map with | Some ids -> ( - match IntMap.find_opt hash ids with + match IntMap.find_opt id ids with | None -> let max_id = snd @@ -111,13 +111,13 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit = in string_counter_map := StringMap.add v_str - (IntMap.add hash (max_id + 1) ids) + (IntMap.add id (max_id + 1) ids) !string_counter_map; max_id + 1 | Some local_id -> local_id) | None -> string_counter_map := - StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; + StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; 0 in if v_str = "_" then Format.fprintf fmt "dummy_var" diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 640094fe..c76a26c2 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -152,20 +152,20 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let v_str = clean_name (Mark.remove (VarName.get_info v)) in - let hash = VarName.hash v in + let id = VarName.id v in let local_id = match StringMap.find_opt v_str !string_counter_map with | Some ids -> ( - match IntMap.find_opt hash ids with + match IntMap.find_opt id ids with | None -> - let id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in + let local_id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in string_counter_map := - StringMap.add v_str (IntMap.add hash id ids) !string_counter_map; - id + StringMap.add v_str (IntMap.add id local_id ids) !string_counter_map; + local_id | Some local_id -> local_id) | None -> string_counter_map := - StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; + StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; 0 in if v_str = "_" then Format.fprintf fmt "_" diff --git a/compiler/scalc/to_r.ml b/compiler/scalc/to_r.ml index 1a368b0f..eb1e82b9 100644 --- a/compiler/scalc/to_r.ml +++ b/compiler/scalc/to_r.ml @@ -220,11 +220,11 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let v_str = Mark.remove (VarName.get_info v) in - let hash = VarName.hash v in + let id = VarName.id v in let local_id = match StringMap.find_opt v_str !string_counter_map with | Some ids -> ( - match IntMap.find_opt hash ids with + match IntMap.find_opt id ids with | None -> let max_id = snd @@ -235,13 +235,13 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit = in string_counter_map := StringMap.add v_str - (IntMap.add hash (max_id + 1) ids) + (IntMap.add id (max_id + 1) ids) !string_counter_map; max_id + 1 | Some local_id -> local_id) | None -> string_counter_map := - StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; + StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; 0 in if v_str = "_" then Format.fprintf fmt "dummy_var" diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index 100cfc10..ff59f580 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -67,7 +67,7 @@ type 'm scope_decl = { } type 'm program = { - program_module_name : ModuleName.t option; + program_module_name : (ModuleName.t * Hash.t) option; program_ctx : decl_ctx; program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index c47d1336..4f611369 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -63,12 +63,13 @@ type 'm scope_decl = { } type 'm program = { - program_module_name : ModuleName.t option; + program_module_name : (ModuleName.t * Hash.t) option; program_ctx : decl_ctx; program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; (* Using [nil] here ensure that program interfaces don't contain any - expressions. They won't contain any rules or topdefs, but will still have - the scope signatures needed to respect the call convention *) + expressions. They won't contain any rules or topdef implementations, but + will still have the scope signatures needed to respect the call + convention *) program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t; program_lang : Global.backend_lang; diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index de664409..fff9363d 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -42,9 +42,7 @@ module SVertex = struct | Topdef g1, Topdef g2 -> TopdefName.equal g1 g2 | (Scope _ | Topdef _), _ -> false - let hash = function - | Scope s -> ScopeName.hash s - | Topdef g -> TopdefName.hash g + let hash = function Scope s -> ScopeName.id s | Topdef g -> TopdefName.id g let format ppf = function | Scope s -> ScopeName.format ppf s @@ -206,7 +204,9 @@ module TVertex = struct type t = Struct of StructName.t | Enum of EnumName.t let hash x = - match x with Struct x -> StructName.hash x | Enum x -> EnumName.hash x + match x with + | Struct x -> StructName.id x + | Enum x -> Hashtbl.hash (`Enum (EnumName.id x)) let compare x y = match x, y with diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 03fa983a..630d6f1e 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -953,8 +953,9 @@ let translate_program let program_topdefs = TopdefName.Map.mapi (fun id -> function - | Some e, ty -> Expr.unbox (translate_expr ctx e), ty - | None, (_, pos) -> + | { D.topdef_expr = Some e; topdef_type = ty; topdef_visibility = _ } -> + Expr.unbox (translate_expr ctx e), ty + | { D.topdef_expr = None; topdef_type = _, pos; _ } -> Message.error ~pos "No definition found for %a" TopdefName.format id) desugared.program_root.module_topdefs in @@ -964,8 +965,7 @@ let translate_program desugared.D.program_root.module_scopes in { - Ast.program_module_name = - Option.map ModuleName.fresh desugared.D.program_module_name; + Ast.program_module_name = desugared.D.program_module_name; Ast.program_topdefs; Ast.program_scopes; Ast.program_ctx = ctx.decl_ctx; diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index dcb80d0e..8596a05b 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -669,7 +669,10 @@ type scope_info = { } (** In practice, this is a DAG: beware of repeated names *) -type module_tree = M of module_tree ModuleName.Map.t [@@caml.unboxed] +type module_tree = M of (Hash.t * module_tree) ModuleName.Map.t +[@@caml.unboxed] + +type visibility = Private | Public type decl_ctx = { ctx_enums : enum_ctx; @@ -688,5 +691,5 @@ type 'e program = { decl_ctx : decl_ctx; code_items : 'e code_item_list; lang : Global.backend_lang; - module_name : ModuleName.t option; + module_name : (ModuleName.t * Hash.t) option; } diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 5e73afa3..634009db 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -1155,29 +1155,52 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list reflect that. *) let evaluate_expr ctx lang e = evaluate_expr ctx lang (addcustom e) -let load_runtime_modules prg = - let load m = +let load_runtime_modules ~hashf prg = + let load (m, mod_hash) = + let hash = hashf mod_hash in let obj_file = Dynlink.adapt_filename File.(Pos.get_file (Mark.get (ModuleName.get_info m)) -.- "cmo") in - if not (Sys.file_exists obj_file) then + (if not (Sys.file_exists obj_file) then + Message.error + ~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") + ~pos:(Mark.get (ModuleName.get_info m)) + "Compiled OCaml object %a@ not@ found.@ Make sure it has been \ + suitably compiled." + File.format obj_file + else + try Dynlink.loadfile obj_file + with Dynlink.Error dl_err -> + Message.error + "While loading compiled module from %a:@;<1 2>@[%a@]" + File.format obj_file Format.pp_print_text + (Dynlink.error_message dl_err)); + match + Runtime.check_module (ModuleName.to_string m) (Hash.to_string hash) + with + | Ok () -> () + | Error bad_hash -> + Message.debug + "Module hash mismatch for %a:@ @[Expected: %a@,Found: %a@]" + ModuleName.format m Hash.format hash + (fun ppf h -> + try Hash.format ppf (Hash.of_string h) + with Failure _ -> Format.pp_print_string ppf "") + bad_hash; Message.error - ~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") - ~pos:(Mark.get (ModuleName.get_info m)) - "Compiled OCaml object %a@ not@ found.@ Make sure it has been suitably \ - compiled." - File.format obj_file - else - try Dynlink.loadfile obj_file - with Dynlink.Error dl_err -> - Message.error "Error loading compiled module from %a:@;<1 2>@[%a@]" - File.format obj_file Format.pp_print_text - (Dynlink.error_message dl_err) + "Module %a@ needs@ recompiling:@ %a@ was@ likely@ compiled@ from@ an@ \ + older@ version@ or@ with@ incompatible@ flags." + ModuleName.format m File.format obj_file + | exception Not_found -> + Message.error + "Module %a@ was loaded from file %a but did not register properly, \ + there is something wrong in its code." + ModuleName.format m File.format obj_file in let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in if modules_list_topo <> [] then Message.debug "Loading shared modules... %a" (Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format) - modules_list_topo; + (List.map fst modules_list_topo); List.iter load modules_list_topo diff --git a/compiler/shared_ast/interpreter.mli b/compiler/shared_ast/interpreter.mli index b6a21894..f89c494e 100644 --- a/compiler/shared_ast/interpreter.mli +++ b/compiler/shared_ast/interpreter.mli @@ -62,6 +62,6 @@ val delcustom : (** Runtime check that the term contains no custom terms (raises [Invalid_argument] if that is the case *) -val load_runtime_modules : _ program -> unit +val load_runtime_modules : hashf:(Hash.t -> Hash.full) -> _ program -> unit (** Dynlink the runtime modules required by the given program, in order to make them callable by the interpreter. *) diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index d3a7f1f1..f3284336 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -89,9 +89,9 @@ let to_expr p main_scope = let modules_to_list (mt : module_tree) = let rec aux acc (M mtree) = ModuleName.Map.fold - (fun mname sub acc -> - if List.exists (ModuleName.equal mname) acc then acc - else mname :: aux acc sub) + (fun mname (subhash, sub) acc -> + if List.exists (fun (m, _) -> ModuleName.equal m mname) acc then acc + else (mname, subhash) :: aux acc sub) mtree acc in List.rev (aux [] mt) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 54a95047..6b4b2e38 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -15,6 +15,7 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils open Definitions (** {2 Program declaration context helpers} *) @@ -53,5 +54,5 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body -val modules_to_list : module_tree -> ModuleName.t list +val modules_to_list : module_tree -> (ModuleName.t * Hash.t) list (** Returns a list of used modules, in topological order *) diff --git a/compiler/shared_ast/type.ml b/compiler/shared_ast/type.ml index 29274ab9..a792e489 100644 --- a/compiler/shared_ast/type.ml +++ b/compiler/shared_ast/type.ml @@ -93,6 +93,22 @@ let rec compare ty1 ty2 = | TClosureEnv, _ -> -1 | _, TClosureEnv -> 1 +let rec hash ~strip ty = + let open Hash.Op in + match Mark.remove ty with + | TLit l -> !`TLit % !(l : typ_lit) + | TTuple tl -> List.fold_left (fun acc ty -> acc % hash ~strip ty) !`TTuple tl + | TStruct n -> !`TStruct % StructName.hash ~strip n + | TEnum n -> !`TEnum % EnumName.hash ~strip n + | TOption ty -> !`TOption % hash ~strip ty + | TArrow (tl, ty) -> + !`TArrow + % List.fold_left (fun acc ty -> acc % hash ~strip ty) (hash ~strip ty) tl + | TArray ty -> !`TArray % hash ~strip ty + | TDefault ty -> !`TDefault % hash ~strip ty + | TAny -> !`TAny + | TClosureEnv -> !`TClosureEnv + let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t let format = Print.typ_debug diff --git a/compiler/shared_ast/type.mli b/compiler/shared_ast/type.mli index ac2ec56d..023300f1 100644 --- a/compiler/shared_ast/type.mli +++ b/compiler/shared_ast/type.mli @@ -14,6 +14,8 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils + type t = Definitions.typ val format : Format.formatter -> t -> unit @@ -23,6 +25,11 @@ module Map : Catala_utils.Map.S with type key = t val equal : t -> t -> bool val equal_list : t list -> t list -> bool val compare : t -> t -> int + +val hash : strip:int -> t -> Hash.t +(** The [strip] argument strips as many leading path components in included + identifiers before hashing *) + val unifiable : t -> t -> bool val unifiable_list : t list -> t list -> bool diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index acaf4abb..88ca7433 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -31,6 +31,7 @@ module Any = let format fmt () = Format.fprintf fmt "any" let equal () () = true let compare () () = 0 + let hash () = Hash.raw `Any end) (struct let style = Ocolor_types.(Fg (C4 hi_magenta)) @@ -166,7 +167,7 @@ let rec format_typ format_typ ~colors fmt t1; Format.pp_print_as fmt 1 "⟩" | TAny v -> - if Global.options.debug then Format.fprintf fmt "" (Any.hash v) + if Global.options.debug then Format.fprintf fmt "" (Any.id v) else Format.pp_print_string fmt "" | TClosureEnv -> Format.fprintf fmt "closure_env" diff --git a/runtimes/ocaml/runtime.ml b/runtimes/ocaml/runtime.ml index 4d626efd..60d279c5 100644 --- a/runtimes/ocaml/runtime.ml +++ b/runtimes/ocaml/runtime.ml @@ -897,7 +897,9 @@ let register_module modname values hash = Hashtbl.add modules_table modname hash; List.iter (fun (id, v) -> Hashtbl.add values_table ([modname], id) v) values -let check_module m h = String.equal (Hashtbl.find modules_table m) h +let check_module m h = + let h1 = Hashtbl.find modules_table m in + if String.equal h h1 then Ok () else Error h1 let lookup_value qid = try Hashtbl.find values_table qid diff --git a/runtimes/ocaml/runtime.mli b/runtimes/ocaml/runtime.mli index 9b9124b5..07b8550d 100644 --- a/runtimes/ocaml/runtime.mli +++ b/runtimes/ocaml/runtime.mli @@ -446,8 +446,8 @@ val register_module : string -> (string * Obj.t) list -> hash -> unit expected to be a hash of the source file and the Catala version, and will in time be used to ensure that the module and the interface are in sync *) -val check_module : string -> hash -> bool -(** Returns [true] if it has been registered with the correct hash, [false] if +val check_module : string -> hash -> (unit, hash) result +(** Returns [Ok] if it has been registered with the correct hash, [Error h] if there is a hash mismatch. @raise Not_found if the module does not exist at all *) diff --git a/tests/modules/good/mod_def.catala_en b/tests/modules/good/mod_def.catala_en index 4b0b1e84..65cd7131 100644 --- a/tests/modules/good/mod_def.catala_en +++ b/tests/modules/good/mod_def.catala_en @@ -19,12 +19,20 @@ declaration scope S: declaration half content decimal depends on x content integer equals x / 2 + +declaration maybe content Enum1 + depends on x content Enum1 ``` ```catala scope S: definition sr equals $1,000 definition e1 equals Maybe + + +declaration maybe content Enum1 + depends on x content Enum1 + equals Maybe ```