diff --git a/build_system/clerk_driver.ml b/build_system/clerk_driver.ml index f377c377..2237d103 100644 --- a/build_system/clerk_driver.ml +++ b/build_system/clerk_driver.ml @@ -548,8 +548,11 @@ let[@ocamlformat "disable"] static_base_rules = Nj.rule "out-test" ~command: [ !catala_exe; !test_command; "--plugin-dir="; "-o -"; !catala_flags; - !input; ">"; !output; "2>&1"; - "||"; "true"; + !input; "2>&1"; + "|"; "sed"; + "'s/\"CM0|[a-zA-Z0-9|]*\"/\"CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX\"/g'"; + ">"; !output; + "||"; "true" ] ~description: [""; "test"; !test_id; "⇐"; !input; "(" ^ !test_command ^ ")"]; diff --git a/build_system/clerk_runtest.ml b/build_system/clerk_runtest.ml index 4f262be8..b9daad31 100644 --- a/build_system/clerk_runtest.ml +++ b/build_system/clerk_runtest.ml @@ -16,10 +16,32 @@ open Catala_utils +let sanitize = + let re_endtest = Re.(compile @@ seq [bol; str "```"]) in + let re_modhash = + Re.( + compile + @@ seq + [ + str "\"CM0|"; + repn xdigit 8 (Some 8); + char '|'; + repn xdigit 8 (Some 8); + char '|'; + repn xdigit 8 (Some 8); + char '"'; + ]) + in + fun str -> + str + |> Re.replace_string re_endtest ~by:"\\```" + |> Re.replace_string re_modhash ~by:"\"CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX\"" + let run_catala_test test_flags catala_exe catala_opts file program args oc = - let cmd_in_rd, cmd_in_wr = Unix.pipe () in - Unix.set_close_on_exec cmd_in_wr; + let cmd_in_rd, cmd_in_wr = Unix.pipe ~cloexec:true () in + let cmd_out_rd, cmd_out_wr = Unix.pipe ~cloexec:true () in let command_oc = Unix.out_channel_of_descr cmd_in_wr in + let command_ic = Unix.in_channel_of_descr cmd_out_rd in let catala_exe = (* If the exe name contains directories, make it absolute. Otherwise don't modify it so that it can be looked up in PATH. *) @@ -59,12 +81,21 @@ let run_catala_test test_flags catala_exe catala_opts file program args oc = |> Seq.cons "CATALA_PLUGINS=" |> Array.of_seq in - flush oc; - let ocfd = Unix.descr_of_out_channel oc in - let pid = Unix.create_process_env catala_exe cmd env cmd_in_rd ocfd ocfd in + let pid = + Unix.create_process_env catala_exe cmd env cmd_in_rd cmd_out_wr cmd_out_wr + in Unix.close cmd_in_rd; + Unix.close cmd_out_wr; Seq.iter (output_string command_oc) program; close_out command_oc; + let out_lines = + Seq.of_dispenser (fun () -> In_channel.input_line command_ic) + in + Seq.iter + (fun line -> + output_string oc (sanitize line); + output_char oc '\n') + out_lines; let return_code = match Unix.waitpid [] pid with | _, Unix.WEXITED n -> n diff --git a/compiler/catala_utils/hash.ml b/compiler/catala_utils/hash.ml new file mode 100644 index 00000000..4854ffab --- /dev/null +++ b/compiler/catala_utils/hash.ml @@ -0,0 +1,110 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2024 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +type t = int + +let mix (h1 : t) (h2 : t) : t = Hashtbl.hash (h1, h2) +let raw = Hashtbl.hash + +module Op = struct + let ( % ) = mix + let ( ! ) = raw +end + +open Op + +let option f = function None -> !`None | Some x -> !`Some % f x +let list hf l = List.fold_left (fun acc x -> acc % hf x) !`ListEmpty l + +let map fold_fun kh vh map = + fold_fun (fun k v acc -> acc lxor (kh k % vh v)) map !`HashMapDelim + +module Flags : sig + type nonrec t = private t + + val pass : + (t -> 'a) -> + avoid_exceptions:bool -> + closure_conversion:bool -> + monomorphize_types:bool -> + 'a + + val of_t : int -> t +end = struct + type nonrec t = t + + let pass k ~avoid_exceptions ~closure_conversion ~monomorphize_types = + (* Should not affect the call convention or actual interfaces: include, + optimize, check_invariants, typed *) + !(avoid_exceptions : bool) + % !(closure_conversion : bool) + % !(monomorphize_types : bool) + % (* The following may not affect the call convention, but we want it set in + an homogeneous way *) + !(Global.options.trace : bool) + % !(Global.options.max_prec_digits : int) + |> k + + let of_t t = t +end + +type full = { catala_version : t; flags_hash : Flags.t; contents : t } + +let finalise t = + Flags.pass (fun flags_hash -> + { catala_version = !(Version.v : string); flags_hash; contents = t }) + +let to_string full = + Printf.sprintf "CM0|%08x|%08x|%08x" full.catala_version + (full.flags_hash :> int) + full.contents + +(* Putting color inside the hash makes them much easier to differentiate at a + glance *) +let format ppf full = + let open Ocolor_types in + let pcolor col f x = + Format.pp_open_stag ppf Ocolor_format.(Ocolor_style_tag (Fg (C24 col))); + f x; + Format.pp_close_stag ppf () + in + let tag = pcolor { r24 = 172; g24 = 172; b24 = 172 } in + let auto i = + { + r24 = 128 + (i mod 128); + g24 = 128 + ((i lsr 10) mod 128); + b24 = 128 + ((i lsr 20) mod 128); + } + in + let phash h = + let col = auto h in + pcolor col (Format.fprintf ppf "%08x") h + in + tag (Format.pp_print_string ppf) "CM0|"; + phash full.catala_version; + tag (Format.pp_print_string ppf) "|"; + phash (full.flags_hash :> int); + tag (Format.pp_print_string ppf) "|"; + phash full.contents + +let of_string s = + try + Scanf.sscanf s "CM0|%08x|%08x|%08x" + (fun catala_version flags_hash contents -> + { catala_version; flags_hash = Flags.of_t flags_hash; contents }) + with Scanf.Scan_failure _ -> failwith "Hash.of_string" + +let external_placeholder = "*external*" diff --git a/compiler/catala_utils/hash.mli b/compiler/catala_utils/hash.mli new file mode 100644 index 00000000..b70ba2f1 --- /dev/null +++ b/compiler/catala_utils/hash.mli @@ -0,0 +1,81 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2024 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +(** Hashes for the identification of modules. + + In contrast with OCaml's basic `Hashtbl.hash`, they process the full depth + of terms. Any meaningful interface change in a module should only be in hash + collision with a 1/2^30 probability. *) + +type t = private int +(** Native Hasthbl.hash hashes, value is truncated to 30 bits whatever the + architecture (positive 31-bit integers) *) + +type full +(** A "full" hash includes the Catala version and compilation flags, alongside + the module interface *) + +val raw : 'a -> t +(** [Hashtbl.hash]. Do not use on deep types (it has a bounded depth), use + specific hashing functions. *) + +module Op : sig + val ( ! ) : 'a -> t + (** Shortcut to [raw]. Same warning: use with an explicit type annotation + [!(foo: string)] to ensure it's not called on types that are recursive or + include annotations. + + Hint: we use [!`Foo] as a fancy way to generate constants for + discriminating constructions *) + + val ( % ) : t -> t -> t + (** Safe combination of two hashes (non commutative or associative, etc.) *) +end + +val option : ('a -> t) -> 'a option -> t +val list : ('a -> t) -> 'a list -> t + +val map : + (('k -> 'v -> t -> t) -> 'map -> t -> t) -> + ('k -> t) -> + ('v -> t) -> + 'map -> + t +(** [map fold_f key_hash_f value_hash_f map] computes the hash of a map. The + first argument is expected to be a [Foo.Map.fold] function. The result is + independent of the ordering of the map. *) + +val finalise : + t -> + avoid_exceptions:bool -> + closure_conversion:bool -> + monomorphize_types:bool -> + full +(** Turns a raw interface hash into a full hash, ready for printing *) + +val to_string : full -> string +val format : Format.formatter -> full -> unit + +val of_string : string -> full +(** @raise Failure *) + +val external_placeholder : string +(** It's inconvenient to need hash updates on external modules. This string is + uses as a hash instead for those cases. + + NOTE: This is a temporary solution A future approach could be to have Catala + generate a module loader (with the proper hash), relieving the user + implementation from having to do the registration. *) diff --git a/compiler/catala_utils/mark.ml b/compiler/catala_utils/mark.ml index 8d50f263..ef1be20c 100644 --- a/compiler/catala_utils/mark.ml +++ b/compiler/catala_utils/mark.ml @@ -29,6 +29,7 @@ let fold f (x, _) = f x let fold2 f (x, _) (y, _) = f x y let compare cmp a b = fold2 cmp a b let equal eq a b = fold2 eq a b +let hash f (x, _) = f x class ['self] marked_map = object (_self : 'self) diff --git a/compiler/catala_utils/mark.mli b/compiler/catala_utils/mark.mli index 95a5ace5..ebbf44a7 100644 --- a/compiler/catala_utils/mark.mli +++ b/compiler/catala_utils/mark.mli @@ -41,6 +41,10 @@ val compare : ('a -> 'a -> int) -> ('a, 'm) ed -> ('a, 'm) ed -> int val equal : ('a -> 'a -> bool) -> ('a, 'm) ed -> ('a, 'm) ed -> bool (** Tests equality of two marked values {b ignoring marks} *) +val hash : ('a -> Hash.t) -> ('a, 'm) ed -> Hash.t +(** Computes the hash of the marked values using the given function + {b ignoring mark} *) + (** Visitors *) class ['self] marked_map : object ('self) diff --git a/compiler/catala_utils/string.ml b/compiler/catala_utils/string.ml index 44dc5e6f..0af8ec76 100644 --- a/compiler/catala_utils/string.ml +++ b/compiler/catala_utils/string.ml @@ -101,6 +101,7 @@ module Arg = struct end let compare = Arg.compare +let hash t = Hash.raw t module Set = Set.Make (Arg) module Map = Map.Make (Arg) diff --git a/compiler/catala_utils/string.mli b/compiler/catala_utils/string.mli index ee7248be..b16b6723 100644 --- a/compiler/catala_utils/string.mli +++ b/compiler/catala_utils/string.mli @@ -23,6 +23,8 @@ module Map : Map.S with type key = string val compare : string -> string -> int (** String comparison with natural ordering of numbers within strings *) +val hash : string -> Hash.t + val to_ascii : string -> string (** Removes all non-ASCII diacritics from a string by converting them to their base letter in the Latin alphabet. *) diff --git a/compiler/catala_utils/uid.ml b/compiler/catala_utils/uid.ml index 23beedae..83c89e52 100644 --- a/compiler/catala_utils/uid.ml +++ b/compiler/catala_utils/uid.ml @@ -21,6 +21,7 @@ module type Info = sig val format : Format.formatter -> info -> unit val equal : info -> info -> bool val compare : info -> info -> int + val hash : info -> Hash.t end module type Id = sig @@ -33,7 +34,8 @@ module type Id = sig val equal : t -> t -> bool val format : Format.formatter -> t -> unit val to_string : t -> string - val hash : t -> int + val id : t -> int + val hash : t -> Hash.t module Set : Set.S with type elt = t module Map : Map.S with type key = t @@ -68,8 +70,9 @@ module Make (X : Info) (S : Style) () : Id with type info = X.info = struct { id = !counter; info } let get_info (uid : t) : X.info = uid.info - let hash (x : t) : int = x.id + let id (x : t) : int = x.id let to_string t = X.to_string t.info + let hash t = X.hash t.info module Set = Set.Make (Ordering) module Map = Map.Make (Ordering) @@ -84,6 +87,7 @@ module MarkedString = struct let format fmt i = String.format fmt (to_string i) let equal = Mark.equal String.equal let compare = Mark.compare String.compare + let hash = Mark.hash String.hash end module Gen (S : Style) () = Make (MarkedString) (S) () @@ -109,6 +113,15 @@ module Path = struct let to_string p = String.concat "." (List.map Module.to_string p) let equal = List.equal Module.equal let compare = List.compare Module.compare + + let strip prefix p0 = + let rec aux prefix p = + match prefix, p with + | pfx1 :: pfx, p1 :: p -> if Module.equal pfx1 p1 then aux pfx p else p0 + | [], p -> p + | _ -> p0 + in + aux prefix p0 end module QualifiedMarkedString = struct @@ -125,12 +138,21 @@ module QualifiedMarkedString = struct let compare (p1, i1) (p2, i2) = match Path.compare p1 p2 with 0 -> MarkedString.compare i1 i2 | n -> n + + let hash (p, i) = + let open Hash.Op in + Hash.list Module.hash p % MarkedString.hash i end module Gen_qualified (S : Style) () = struct include Make (QualifiedMarkedString) (S) () let fresh path t = fresh (path, t) + + let hash ~strip t = + let p, i = get_info t in + QualifiedMarkedString.hash (Path.strip strip p, i) + let path t = fst (get_info t) let get_info t = snd (get_info t) end diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index 4ed2e8d7..a51e6cf4 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -28,6 +28,9 @@ module type Info = sig val compare : info -> info -> int (** Comparison disregards position *) + + val hash : info -> Hash.t + (** Hashing disregards position *) end module MarkedString : Info with type info = string Mark.pos @@ -48,7 +51,15 @@ module type Id = sig val equal : t -> t -> bool val format : Format.formatter -> t -> unit val to_string : t -> string - val hash : t -> int + + val id : t -> int + (** Returns the unique ID of the identifier *) + + val hash : t -> Hash.t + (** While [id] returns a unique ID valable for a given Uid instance within a + given run of catala, this is a raw hash of the identifier string. + Therefore, it may collide within a given program, but remains meaninful + across separate compilations. *) module Set : Set.S with type elt = t module Map : Map.S with type key = t @@ -79,6 +90,10 @@ module Path : sig val format : Format.formatter -> t -> unit val equal : t -> t -> bool val compare : t -> t -> int + + val strip : t -> t -> t + (** [strip pfx p] removed [pfx] from the start of [p]. if [p] doesn't start + with [pfx], it is returned unchanged *) end (** Same as [Gen] but also registers path information *) @@ -88,4 +103,6 @@ module Gen_qualified (_ : Style) () : sig val fresh : Path.t -> MarkedString.info -> t val path : t -> Path.t val get_info : t -> MarkedString.info + val hash : strip:Path.t -> t -> Hash.t + (* [strip] strips that prefix from the start of the path before hashing *) end diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 9ac16a16..19ded35b 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -71,12 +71,17 @@ module ScopeDef = struct ScopeVar.format ppf (Mark.remove v); format_kind ppf k - let hash_kind = function - | Var None -> 0 - | Var (Some st) -> StateName.hash st - | SubScopeInput { var_within_origin_scope = v; _ } -> ScopeVar.hash v + open Hash.Op - let hash (v, k) = Int.logxor (ScopeVar.hash (Mark.remove v)) (hash_kind k) + let hash_kind ~strip = function + | Var v -> !`Var % Hash.option StateName.hash v + | SubScopeInput { name; var_within_origin_scope } -> + !`SubScopeInput + % ScopeName.hash ~strip name + % ScopeVar.hash var_within_origin_scope + + let hash ~strip (v, k) = + Hash.Op.(ScopeVar.hash (Mark.remove v) % hash_kind ~strip k) end include Base @@ -231,6 +236,8 @@ type scope_def = { type var_or_states = WholeVar | States of StateName.t list +(* If fields are added, make sure to consider including them in the hash + computations below *) type scope = { scope_vars : var_or_states ScopeVar.Map.t; scope_sub_scopes : ScopeName.t ScopeVar.Map.t; @@ -239,21 +246,76 @@ type scope = { scope_assertions : assertion AssertionName.Map.t; scope_options : catala_option Mark.pos list; scope_meta_assertions : meta_assertion list; + scope_visibility : visibility; +} + +type topdef = { + topdef_expr : expr option; + topdef_type : typ; + topdef_visibility : visibility; } type modul = { module_scopes : scope ScopeName.Map.t; - module_topdefs : (expr option * typ) TopdefName.Map.t; + module_topdefs : topdef TopdefName.Map.t; } type program = { - program_module_name : Ident.t Mark.pos option; + program_module_name : (ModuleName.t * module_intf_id) option; program_ctx : decl_ctx; program_modules : modul ModuleName.Map.t; program_root : modul; program_lang : Global.backend_lang; } +module Hash = struct + open Hash.Op + + let var_or_state = function + | WholeVar -> !`WholeVar + | States s -> !`States % Hash.list StateName.hash s + + let io x = + !(Mark.remove x.io_input : Runtime.io_input) + % !(Mark.remove x.io_output : bool) + + let scope_decl ~strip d = + (* scope_def_rules is ignored (not part of the interface) *) + Type.hash ~strip d.scope_def_typ + % Hash.option + (fun (lst, _) -> + List.fold_left + (fun acc (name, ty) -> + acc % Uid.MarkedString.hash name % Type.hash ~strip ty) + !`SDparams lst) + d.scope_def_parameters + % !(d.scope_def_is_condition : bool) + % io d.scope_def_io + + let scope ~strip s = + Hash.map ScopeVar.Map.fold ScopeVar.hash var_or_state s.scope_vars + % Hash.map ScopeVar.Map.fold ScopeVar.hash (ScopeName.hash ~strip) + s.scope_sub_scopes + % ScopeName.hash ~strip s.scope_uid + % Hash.map ScopeDef.Map.fold (ScopeDef.hash ~strip) (scope_decl ~strip) + s.scope_defs + (* assertions, options, etc. are not expected to be part of interfaces *) + + let modul ?(strip = []) m = + Hash.map ScopeName.Map.fold (ScopeName.hash ~strip) (scope ~strip) + (ScopeName.Map.filter + (fun _ s -> s.scope_visibility = Public) + m.module_scopes) + % Hash.map TopdefName.Map.fold (TopdefName.hash ~strip) + (fun td -> Type.hash ~strip td.topdef_type) + (TopdefName.Map.filter + (fun _ td -> td.topdef_visibility = Public) + m.module_topdefs) + + let module_binding modname m = + ModuleName.hash modname % modul ~strip:[modname] m +end + let rec locations_used e : LocationSet.t = match e with | ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m) @@ -311,5 +373,5 @@ let fold_exprs ~(f : 'a -> expr -> 'a) ~(init : 'a) (p : program) : 'a = p.program_root.module_scopes init in TopdefName.Map.fold - (fun _ (e, _) acc -> Option.fold ~none:acc ~some:(f acc) e) + (fun _ tdef acc -> Option.fold ~none:acc ~some:(f acc) tdef.topdef_expr) p.program_root.module_topdefs acc diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index b8f787bf..88f263b2 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -32,7 +32,7 @@ module ScopeDef : sig val equal_kind : kind -> kind -> bool val compare_kind : kind -> kind -> int val format_kind : Format.formatter -> kind -> unit - val hash_kind : kind -> int + val hash_kind : strip:Uid.Path.t -> kind -> Hash.t type t = ScopeVar.t Mark.pos * kind @@ -40,7 +40,7 @@ module ScopeDef : sig val compare : t -> t -> int val get_position : t -> Pos.t val format : Format.formatter -> t -> unit - val hash : t -> int + val hash : strip:Uid.Path.t -> t -> Hash.t module Map : Map.S with type key = t module Set : Set.S with type elt = t @@ -123,16 +123,23 @@ type scope = { (** empty outside of the root module *) scope_options : catala_option Mark.pos list; scope_meta_assertions : meta_assertion list; + scope_visibility : visibility; +} + +type topdef = { + topdef_expr : expr option; (** Always [None] outside of the root module *) + topdef_type : typ; + topdef_visibility : visibility; + (** Necessarily [Public] outside of the root module *) } type modul = { module_scopes : scope ScopeName.Map.t; - module_topdefs : (expr option * typ) TopdefName.Map.t; - (** the expr is [None] outside of the root module *) + module_topdefs : topdef TopdefName.Map.t; } type program = { - program_module_name : Ident.t Mark.pos option; + program_module_name : (ModuleName.t * module_intf_id) option; program_ctx : decl_ctx; program_modules : modul ModuleName.Map.t; (** Contains all submodules of the program, in a flattened structure *) @@ -140,6 +147,18 @@ type program = { program_lang : Global.backend_lang; } +(** {1 Interface hash computations} *) + +(** These hashes are computed on interfaces: only signatures are considered. *) +module Hash : sig + (** The [strip] argument below strips as many leading path components before + hashing *) + + val scope : strip:Uid.Path.t -> scope -> Hash.t + val modul : ?strip:Uid.Path.t -> modul -> Hash.t + val module_binding : ModuleName.t -> modul -> Hash.t +end + (** {1 Helpers} *) val locations_used : expr -> LocationSet.t diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 51490cdf..3006560c 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -39,9 +39,9 @@ module Vertex = struct let hash x = match x with - | Var (x, None) -> ScopeVar.hash x - | Var (x, Some sx) -> Int.logxor (ScopeVar.hash x) (StateName.hash sx) - | Assertion a -> Ast.AssertionName.hash a + | Var (x, None) -> ScopeVar.id x + | Var (x, Some sx) -> Hashtbl.hash (ScopeVar.id x, StateName.id sx) + | Assertion a -> Hashtbl.hash (`Assert (Ast.AssertionName.id a)) let compare x y = match x, y with @@ -252,7 +252,7 @@ module ExceptionVertex = struct let hash (x : t) : int = RuleName.Map.fold - (fun r _ acc -> Int.logxor (RuleName.hash r) acc) + (fun r _ acc -> Hashtbl.hash (RuleName.id r, acc)) x.rules 0 let equal x y = compare x y = 0 diff --git a/compiler/desugared/disambiguate.ml b/compiler/desugared/disambiguate.ml index 024645f6..2de18f72 100644 --- a/compiler/desugared/disambiguate.ml +++ b/compiler/desugared/disambiguate.ml @@ -98,10 +98,14 @@ let program prg = in let module_topdefs = TopdefName.Map.map - (function - | Some e, ty -> - Some (Expr.unbox (expr prg.program_ctx env (Expr.box e))), ty - | None, ty -> None, ty) + (fun def -> + { + def with + topdef_expr = + Option.map + (fun e -> Expr.unbox (expr prg.program_ctx env (Expr.box e))) + def.topdef_expr; + }) prg.program_root.module_topdefs in let module_scopes = diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index d9af0fc5..01d4ff47 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -313,7 +313,7 @@ let rec translate_expr in let e2 = rec_helper ~local_vars e2 in Expr.make_abs [| binding_var |] e2 [tau] pos_op) - (EnumName.Map.find enum_uid ctxt.enums) + (fst (EnumName.Map.find enum_uid ctxt.enums)) in Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark | Binop ((((S.And | S.Or | S.Xor), _) as op), e1, e2) -> @@ -613,7 +613,7 @@ let rec translate_expr StructField.Map.add f_uid f_e s_fields) StructField.Map.empty fields in - let expected_s_fields = StructName.Map.find s_uid ctxt.structs in + let expected_s_fields, _ = StructName.Map.find s_uid ctxt.structs in if StructField.Map.exists (fun expected_f _ -> not (StructField.Map.mem expected_f s_fields)) @@ -719,7 +719,7 @@ let rec translate_expr Expr.make_abs [| nop_var |] (Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark) [tau] pos) - (EnumName.Map.find enum_uid ctxt.enums) + (fst (EnumName.Map.find enum_uid ctxt.enums)) in Expr.ematch ~e:(rec_helper e1) ~name:enum_uid ~cases emark | ArrayLit es -> Expr.earray (List.map rec_helper es) emark @@ -970,7 +970,7 @@ and disambiguate_match_and_build_expression Expr.eabs e_binder [ EnumConstructor.Map.find c_uid - (EnumName.Map.find e_uid ctxt.Name_resolution.enums); + (fst (EnumName.Map.find e_uid ctxt.Name_resolution.enums)); ] (Mark.get case_body) in @@ -1037,6 +1037,7 @@ and disambiguate_match_and_build_expression if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err (); let missing_constructors = EnumName.Map.find e_uid ctxt.Name_resolution.enums + |> fst |> EnumConstructor.Map.filter_map (fun c_uid _ -> match EnumConstructor.Map.find_opt c_uid cases_d with | Some _ -> None @@ -1451,6 +1452,7 @@ let process_scope_use let process_topdef (ctxt : Name_resolution.context) (prgm : Ast.program) + (is_public : bool) (def : S.top_def) : Ast.program = let id = Ident.Map.find @@ -1493,12 +1495,14 @@ let process_topdef in Some (Expr.unbox_closed e) in + let topdef_visibility = if is_public then Public else Private in let module_topdefs = TopdefName.Map.update id (fun def0 -> match def0, expr_opt with - | None, eopt -> Some (eopt, typ) - | Some (eopt0, ty0), eopt -> ( + | None, eopt -> + Some { Ast.topdef_expr = eopt; topdef_visibility; topdef_type = typ } + | Some def0, eopt -> ( let err msg = Message.error ~extra_pos: @@ -1508,13 +1512,16 @@ let process_topdef ] (msg ^^ " for %a") TopdefName.format id in - if not (Type.equal ty0 typ) then err "Conflicting type definitions" + if not (Type.equal def0.Ast.topdef_type typ) then + err "Conflicting type definitions" else - match eopt0, eopt with + match def0.Ast.topdef_expr, eopt with | None, None -> err "Multiple declarations" | Some _, Some _ -> err "Multiple definitions" - | Some e, None -> Some (Some e, typ) - | None, Some e -> Some (Some e, ty0))) + | (Some _ as topdef_expr), None -> + Some { Ast.topdef_expr; topdef_visibility; topdef_type = typ } + | None, (Some _ as topdef_expr) -> + Some { def0 with Ast.topdef_expr })) prgm.Ast.program_root.module_topdefs in { prgm with program_root = { prgm.program_root with module_topdefs } } @@ -1680,6 +1687,7 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : scope_meta_assertions = []; scope_options = []; scope_uid = s_uid; + scope_visibility = s_context.Name_resolution.scope_visibility; } in let get_scopes mctx = @@ -1692,22 +1700,32 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : mctx.Name_resolution.typedefs ScopeName.Map.empty in let program_modules = - ModuleName.Map.map - (fun mctx -> - { - Ast.module_scopes = get_scopes mctx; - Ast.module_topdefs = - Ident.Map.fold - (fun _ name acc -> - TopdefName.Map.add name - ( None, - TopdefName.Map.find name ctxt.Name_resolution.topdef_types - ) - acc) - mctx.topdefs TopdefName.Map.empty; - }) + ModuleName.Map.mapi + (fun mname mctx -> + let m = + { + Ast.module_scopes = get_scopes mctx; + Ast.module_topdefs = + Ident.Map.fold + (fun _ name acc -> + let topdef_type, topdef_visibility = + TopdefName.Map.find name ctxt.Name_resolution.topdefs + in + TopdefName.Map.add name + { Ast.topdef_expr = None; topdef_visibility; topdef_type } + acc) + mctx.topdefs TopdefName.Map.empty; + } + in + m, Ast.Hash.module_binding mname m) ctxt.modules in + let program_root = + { + Ast.module_scopes = get_scopes ctxt.Name_resolution.local; + Ast.module_topdefs = TopdefName.Map.empty; + } + in let program_ctx = let open Name_resolution in let ctx_scopes mctx acc = @@ -1720,23 +1738,30 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : in let ctx_modules = let rec aux mctx = - Ident.Map.fold - (fun _ m (M acc) -> - let sub = aux (ModuleName.Map.find m ctxt.modules) in - M (ModuleName.Map.add m sub acc)) - mctx.used_modules (M ModuleName.Map.empty) + let subs = + Ident.Map.fold + (fun _ m acc -> + let mctx = ModuleName.Map.find m ctxt.Name_resolution.modules in + let deps = aux mctx in + let hash = snd (ModuleName.Map.find m program_modules) in + ModuleName.Map.add m + { deps; intf_id = { hash; is_external = mctx.is_external } } + acc) + mctx.used_modules ModuleName.Map.empty + in + subs in aux ctxt.local in { - ctx_structs = ctxt.structs; - ctx_enums = ctxt.enums; + ctx_structs = StructName.Map.map fst ctxt.structs; + ctx_enums = EnumName.Map.map fst ctxt.enums; ctx_scopes = ModuleName.Map.fold (fun _ -> ctx_scopes) ctxt.modules (ctx_scopes ctxt.local ScopeName.Map.empty); - ctx_topdefs = ctxt.topdef_types; + ctx_topdefs = TopdefName.Map.map fst ctxt.topdefs; ctx_struct_fields = ctxt.local.field_idmap; ctx_enum_constrs = ctxt.local.constructor_idmap; ctx_scope_index = @@ -1748,25 +1773,29 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : ctx_modules; } in + let program_module_name = + surface.Surface.Ast.program_module + |> Option.map + @@ fun { Surface.Ast.module_name; module_external } -> + let mname = ModuleName.fresh module_name in + let hash_placeholder = Hash.raw 0 in + mname, { hash = hash_placeholder; is_external = module_external } + in let desugared = { Ast.program_lang = surface.program_lang; - Ast.program_module_name = surface.Surface.Ast.program_module_name; - Ast.program_modules; + Ast.program_module_name; + Ast.program_modules = ModuleName.Map.map fst program_modules; Ast.program_ctx; - Ast.program_root = - { - Ast.module_scopes = get_scopes ctxt.Name_resolution.local; - Ast.module_topdefs = TopdefName.Map.empty; - }; + Ast.program_root; } in - let process_code_block ctxt prgm block = + let process_code_block ctxt prgm is_meta block = List.fold_left (fun prgm item -> match Mark.remove item with | S.ScopeUse use -> process_scope_use ctxt prgm use - | S.Topdef def -> process_topdef ctxt prgm def + | S.Topdef def -> process_topdef ctxt prgm is_meta def | S.ScopeDecl _ | S.StructDecl _ | S.EnumDecl _ -> prgm) prgm block in @@ -1777,7 +1806,22 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : List.fold_left (fun prgm child -> process_structure prgm child) prgm children - | S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block + | S.CodeBlock (block, _, is_meta) -> + process_code_block ctxt prgm is_meta block | S.ModuleDef _ | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm in - List.fold_left process_structure desugared surface.S.program_items + let desugared = + List.fold_left process_structure desugared surface.S.program_items + in + { + desugared with + Ast.program_module_name = + (desugared.Ast.program_module_name + |> Option.map + @@ fun (mname, intf_id) -> + ( mname, + { + intf_id with + hash = Ast.Hash.module_binding mname desugared.Ast.program_root; + } )); + } diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index 4c758094..0352aafc 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -39,6 +39,7 @@ type scope_context = { scope_out_struct : StructName.t; sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) + scope_visibility : visibility; } (** Inside a scope, we distinguish between the variables and the subscopes. *) @@ -77,15 +78,17 @@ type module_context = { between different enums *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) used_modules : ModuleName.t Ident.Map.t; + is_external : bool; } (** Context for name resolution, valid within a given module *) type context = { scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) - topdef_types : typ TopdefName.Map.t; - structs : struct_context StructName.Map.t; + topdefs : (typ * visibility) TopdefName.Map.t; + structs : (struct_context * visibility) StructName.Map.t; (** For each struct, its context *) - enums : enum_context EnumName.Map.t; (** For each enum, its context *) + enums : (enum_context * visibility) EnumName.Map.t; + (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) modules : module_context ModuleName.Map.t; @@ -426,8 +429,10 @@ let process_data_decl } (** Process a struct declaration *) -let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : - context = +let process_struct_decl + ?(visibility = Public) + (ctxt : context) + (sdecl : Surface.Ast.struct_decl) : context = let s_uid = get_struct ctxt sdecl.struct_decl_name in if sdecl.struct_decl_fields = [] then Message.error @@ -454,25 +459,28 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : let ctxt = { ctxt with local } in let structs = StructName.Map.update s_uid - (fun fields -> - match fields with + (function | None -> Some - (StructField.Map.singleton f_uid - (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) - | Some fields -> + ( StructField.Map.singleton f_uid + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ), + visibility ) + | Some (fields, _) -> Some - (StructField.Map.add f_uid - (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) - fields)) + ( StructField.Map.add f_uid + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) + fields, + visibility )) ctxt.structs in { ctxt with structs }) ctxt sdecl.struct_decl_fields (** Process an enum declaration *) -let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context - = +let process_enum_decl + ?(visibility = Public) + (ctxt : context) + (edecl : Surface.Ast.enum_decl) : context = let e_uid = get_enum ctxt edecl.enum_decl_name in if List.length edecl.enum_decl_cases = 0 then Message.error @@ -506,23 +514,24 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context | Some typ -> process_type ctxt typ in match cases with - | None -> Some (EnumConstructor.Map.singleton c_uid typ) - | Some fields -> Some (EnumConstructor.Map.add c_uid typ fields)) + | None -> Some (EnumConstructor.Map.singleton c_uid typ, visibility) + | Some (fields, _) -> + Some (EnumConstructor.Map.add c_uid typ fields, visibility)) ctxt.enums in { ctxt with enums }) ctxt edecl.enum_decl_cases -let process_topdef ctxt def = +let process_topdef ?(visibility = Public) ctxt def = let uid = Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.local.topdefs in { ctxt with - topdef_types = + topdefs = TopdefName.Map.add uid - (process_type ctxt def.Surface.Ast.topdef_type) - ctxt.topdef_types; + (process_type ctxt def.Surface.Ast.topdef_type, visibility) + ctxt.topdefs; } (** Process an item declaration *) @@ -536,8 +545,10 @@ let process_item_decl process_subscope_decl scope ctxt sub_decl (** Process a scope declaration *) -let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : - context = +let process_scope_decl + ?(visibility = Public) + (ctxt : context) + (decl : Surface.Ast.scope_decl) : context = let scope_uid = get_scope ctxt decl.scope_decl_name in let ctxt = List.fold_left @@ -588,11 +599,12 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : structs = StructName.Map.add (get_struct ctxt decl.scope_decl_name) - StructField.Map.empty ctxt.structs; + (StructField.Map.empty, visibility) + ctxt.structs; } else let ctxt = - process_struct_decl ctxt + process_struct_decl ~visibility ctxt { struct_decl_name = decl.scope_decl_name; struct_decl_fields = output_fields; @@ -634,8 +646,10 @@ let typedef_info = function | TScope (s, _) -> ScopeName.get_info s (** Process the names of all declaration items *) -let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : - context = +let process_name_item + ?(visibility = Public) + (ctxt : context) + (item : Surface.Ast.code_item Mark.pos) : context = let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg = Message.error ~fmt_pos: @@ -676,6 +690,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : scope_in_struct = in_struct_name; scope_out_struct = out_struct_name; sub_scopes = ScopeName.Set.empty; + scope_visibility = visibility; } ctxt.scopes in @@ -720,14 +735,16 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : { ctxt with local = { ctxt.local with topdefs } } (** Process a code item that is a declaration *) -let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : - context = +let process_decl_item + ?visibility + (ctxt : context) + (item : Surface.Ast.code_item Mark.pos) : context = match Mark.remove item with - | ScopeDecl decl -> process_scope_decl ctxt decl - | StructDecl sdecl -> process_struct_decl ctxt sdecl - | EnumDecl edecl -> process_enum_decl ctxt edecl + | ScopeDecl decl -> process_scope_decl ?visibility ctxt decl + | StructDecl sdecl -> process_struct_decl ?visibility ctxt sdecl + | EnumDecl edecl -> process_enum_decl ?visibility ctxt edecl | ScopeUse _ -> ctxt - | Topdef def -> process_topdef ctxt def + | Topdef def -> process_topdef ?visibility ctxt def (** Process a code block *) let process_code_block @@ -738,7 +755,11 @@ let process_code_block (** Process a law structure, only considering the code blocks *) let rec process_law_structure - (process_item : context -> Surface.Ast.code_item Mark.pos -> context) + (process_item : + ?visibility:visibility -> + context -> + Surface.Ast.code_item Mark.pos -> + context) (ctxt : context) (s : Surface.Ast.law_structure) : context = match s with @@ -746,10 +767,14 @@ let rec process_law_structure List.fold_left (fun ctxt child -> process_law_structure process_item ctxt child) ctxt children - | Surface.Ast.CodeBlock (block, _, _) -> - process_code_block process_item ctxt block + | Surface.Ast.CodeBlock (block, _, is_meta) -> + process_code_block + (process_item ~visibility:(if is_meta then Public else Private)) + ctxt block + | Surface.Ast.ModuleDef (_, is_external) -> + { ctxt with local = { ctxt.local with is_external } } | Surface.Ast.LawInclude _ | Surface.Ast.LawText _ -> ctxt - | Surface.Ast.ModuleDef _ | Surface.Ast.ModuleUse _ -> ctxt + | Surface.Ast.ModuleUse _ -> ctxt (** {1 Scope uses pass} *) @@ -957,12 +982,13 @@ let empty_module_ctxt = constructor_idmap = Ident.Map.empty; topdefs = Ident.Map.empty; used_modules = Ident.Map.empty; + is_external = false; } let empty_ctxt = { scopes = ScopeName.Map.empty; - topdef_types = TopdefName.Map.empty; + topdefs = TopdefName.Map.empty; var_typs = ScopeVar.Map.empty; structs = StructName.Map.empty; enums = EnumName.Map.empty; @@ -985,7 +1011,13 @@ let form_context (surface, mod_uses) surface_modules : context = let ctxt = { ctxt with - local = { ctxt.local with used_modules = mod_uses; path = [m] }; + local = + { + ctxt.local with + used_modules = mod_uses; + path = [m]; + is_external = intf.Surface.Ast.intf_modname.module_external; + }; } in let ctxt = @@ -1017,7 +1049,7 @@ let form_context (surface, mod_uses) surface_modules : context = in let ctxt = List.fold_left - (process_law_structure process_use_item) + (process_law_structure (fun ?visibility:_ -> process_use_item)) ctxt surface.Surface.Ast.program_items in (* Gather struct fields and enum constrs from direct modules: this helps with diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index ce54061d..99356bd1 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -39,6 +39,7 @@ type scope_context = { scope_out_struct : StructName.t; sub_scopes : ScopeName.Set.t; (** Other scopes referred to by this scope. Used for dependency analysis *) + scope_visibility : visibility; } (** Inside a scope, we distinguish between the variables and the subscopes. *) @@ -82,16 +83,18 @@ type module_context = { topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) used_modules : ModuleName.t Ident.Map.t; (** Module aliases and the modules they point to *) + is_external : bool; } (** Context for name resolution, valid within a given module *) type context = { scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) - topdef_types : typ TopdefName.Map.t; + topdefs : (typ * visibility) TopdefName.Map.t; (** Types associated with the global definitions *) - structs : struct_context StructName.Map.t; + structs : (struct_context * visibility) StructName.Map.t; (** For each struct, its context *) - enums : enum_context EnumName.Map.t; (** For each enum, its context *) + enums : (enum_context * visibility) EnumName.Map.t; + (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) modules : module_context ModuleName.Map.t; diff --git a/compiler/driver.ml b/compiler/driver.ml index adfc07ee..6e0cd14a 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -93,7 +93,7 @@ let load_module_interfaces Surface.Parser_driver.load_interface ?default_module_name (Global.FileName f) in - let modname = ModuleName.fresh intf.intf_modname in + let modname = ModuleName.fresh intf.intf_modname.module_name in let seen = File.Map.add f None seen in let seen, sub_use_map = aux @@ -107,9 +107,9 @@ let load_module_interfaces (seen, Ident.Map.empty) uses in let seen = - match program.Surface.Ast.program_module_name with + match program.Surface.Ast.program_module with | Some m -> - let file = Pos.get_file (Mark.get m) in + let file = Pos.get_file (Mark.get m.module_name) in File.Map.singleton file None | None -> File.Map.empty in @@ -712,7 +712,12 @@ module Commands = struct let prg, _ = Passes.dcalc options ~includes ~optimize ~check_invariants ~typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules + ~hashf: + Hash.( + finalise ~avoid_exceptions:false ~closure_conversion:false + ~monomorphize_types:false) + prg; print_interpretation_results options Interpreter.interpret_program_dcalc prg (get_scopeopt_uid prg.decl_ctx ex_scope_opt) @@ -781,7 +786,10 @@ module Commands = struct Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~monomorphize_types ~typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules + ~hashf: + (Hash.finalise ~avoid_exceptions ~closure_conversion ~monomorphize_types) + prg; print_interpretation_results options Interpreter.interpret_program_lcalc prg (get_scopeopt_uid prg.decl_ctx ex_scope_opt) @@ -844,7 +852,11 @@ module Commands = struct Message.debug "Writing to %s..." (Option.value ~default:"stdout" output_file); let exec_scope = Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt in - Lcalc.To_ocaml.format_program fmt prg ?exec_scope type_ordering + let hashf = + Hash.finalise ~avoid_exceptions ~closure_conversion:false + ~monomorphize_types:false + in + Lcalc.To_ocaml.format_program fmt prg ?exec_scope ~hashf type_ordering let ocaml_cmd = Cmd.v @@ -1010,7 +1022,7 @@ module Commands = struct let prg = Surface.Ast. { - program_module_name = None; + program_module = None; program_items = []; program_source_files = []; program_used_modules = @@ -1038,7 +1050,7 @@ module Commands = struct in Format.open_hbox (); Format.pp_print_list ~pp_sep:Format.pp_print_space - (fun ppf m -> + (fun ppf (m, _) -> let f = Pos.get_file (Mark.get (ModuleName.get_info m)) in let f = match prefix with diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 9b7a8e8b..e1c65611 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -716,9 +716,21 @@ let commands = if commands = [] then entry_scopes else commands name format_var var name) scopes_with_no_input -let reexport_used_modules fmt modules = +let check_and_reexport_used_modules fmt ~hashf modules = List.iter - (fun m -> + (fun (m, intf_id) -> + Format.fprintf fmt + "@[let () =@ @[match Runtime_ocaml.Runtime.check_module \ + %S \"%a\"@ with@]@,\ + | Ok () -> ()@,\ + @[| Error h -> failwith \"Hash mismatch for module %a, it may \ + need recompiling\"@]@]@," + (ModuleName.to_string m) + (fun ppf h -> + if intf_id.is_external then + Format.pp_print_string ppf Hash.external_placeholder + else Hash.format ppf h) + (hashf intf_id.hash) ModuleName.format m; Format.fprintf fmt "@[module %a@ = %a@]@," ModuleName.format m ModuleName.format m) modules @@ -726,7 +738,9 @@ let reexport_used_modules fmt modules = let format_module_registration fmt (bnd : ('m Ast.expr Var.t * _) String.Map.t) - modname = + modname + hash + is_external = Format.pp_open_vbox fmt 2; Format.pp_print_string fmt "let () ="; Format.pp_print_space fmt (); @@ -743,11 +757,17 @@ let format_module_registration (fun fmt (id, (var, _)) -> Format.fprintf fmt "@[%S,@ Obj.repr %a@]" id format_var var) fmt (String.Map.to_seq bnd); + (* TODO: pass the visibility info down from desugared, and filter what is + exported here *) Format.pp_close_box fmt (); Format.pp_print_char fmt ' '; Format.pp_print_string fmt "]"; Format.pp_print_space fmt (); - Format.pp_print_string fmt "\"todo-module-hash\""; + Format.fprintf fmt "\"%a\"" + (fun ppf h -> + if is_external then Format.pp_print_string ppf Hash.external_placeholder + else Hash.format ppf h) + hash; Format.pp_close_box fmt (); Format.pp_close_box fmt (); Format.pp_print_newline fmt () @@ -766,17 +786,21 @@ let format_program (fmt : Format.formatter) ?exec_scope ?(exec_args = true) + ~(hashf : Hash.t -> Hash.full) (p : 'm Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = Format.pp_open_vbox fmt 0; Format.pp_print_string fmt header; - reexport_used_modules fmt (Program.modules_to_list p.decl_ctx.ctx_modules); + check_and_reexport_used_modules fmt ~hashf + (Program.modules_to_list p.decl_ctx.ctx_modules); format_ctx type_ordering fmt p.decl_ctx; let bnd = format_code_items p.decl_ctx fmt p.code_items in Format.pp_print_cut fmt (); let () = match p.module_name, exec_scope with - | Some modname, None -> format_module_registration fmt bnd modname + | Some (modname, intf_id), None -> + format_module_registration fmt bnd modname (hashf intf_id.hash) + intf_id.is_external | None, Some scope_name -> let scope_body = Program.get_scope_body p scope_name in format_scope_exec p.decl_ctx fmt bnd scope_name scope_body diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index b6f6a9f5..85f49490 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -14,6 +14,7 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils open Shared_ast (** Formats a lambda calculus program into a valid OCaml program *) @@ -40,6 +41,7 @@ val format_program : Format.formatter -> ?exec_scope:ScopeName.t -> ?exec_args:bool -> + hashf:(Hash.t -> Hash.full) -> 'm Ast.program -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index 63cf61fc..73d1a22a 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -489,7 +489,7 @@ let run (Option.value ~default:"stdout" jsoo_output_file); let modname = match prg.module_name with - | Some m -> ModuleName.to_string m + | Some (m, _) -> ModuleName.to_string m | None -> String.capitalize_ascii Filename.( diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index f059eb18..929f15f9 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -1381,7 +1381,10 @@ let run includes optimize ex_scope explain_options global_options = Driver.Passes.dcalc global_options ~includes ~optimize ~check_invariants:false ~typed:Expr.typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules prg + ~hashf: + (Hash.finalise ~avoid_exceptions:false ~closure_conversion:false + ~monomorphize_types:false); let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in (* let result_expr, env = interpret_program prg scope in *) let g, base_vars, env = program_to_graph explain_options prg scope in diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index 8ae7826e..7a2d4d02 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -271,7 +271,10 @@ let run includes optimize check_invariants ex_scope options = Driver.Passes.dcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules prg + ~hashf: + (Hash.finalise ~avoid_exceptions:false ~closure_conversion:false + ~monomorphize_types:false); let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in let result_expr, _env = interpret_program prg scope in let fmt = Format.std_formatter in diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index 7beb09fe..48c2b2da 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -121,5 +121,5 @@ type ctx = { decl_ctx : decl_ctx; modules : VarName.t ModuleName.Map.t } type program = { ctx : ctx; code_items : code_item list; - module_name : ModuleName.t option; + module_name : (ModuleName.t * module_intf_id) option; } diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index c80024b7..ca4933eb 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -659,7 +659,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) : A.program = let modules = List.fold_left - (fun acc m -> + (fun acc (m, _) -> let vname = Mark.map (( ^ ) "Module_") (ModuleName.get_info m) in (* The "Module_" prefix is a workaround name clashes for same-name structs and modules, Python in particular mixes everything in one diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 541cdbde..9cbf411b 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -21,10 +21,10 @@ open Ast let needs_parens (_e : expr) : bool = false let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit = - Format.fprintf fmt "%a_%d" VarName.format v (VarName.hash v) + Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v) let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = - Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.hash v) + Format.fprintf fmt "@{%a_%d@}" FuncName.format v (FuncName.id v) let rec format_expr (decl_ctx : decl_ctx) diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index de698df5..b46541da 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -96,11 +96,11 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let v_str = Mark.remove (VarName.get_info v) in - let hash = VarName.hash v in + let id = VarName.id v in let local_id = match StringMap.find_opt v_str !string_counter_map with | Some ids -> ( - match IntMap.find_opt hash ids with + match IntMap.find_opt id ids with | None -> let max_id = snd @@ -111,13 +111,13 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit = in string_counter_map := StringMap.add v_str - (IntMap.add hash (max_id + 1) ids) + (IntMap.add id (max_id + 1) ids) !string_counter_map; max_id + 1 | Some local_id -> local_id) | None -> string_counter_map := - StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; + StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; 0 in if v_str = "_" then Format.fprintf fmt "dummy_var" diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 640094fe..c76a26c2 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -152,20 +152,20 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let v_str = clean_name (Mark.remove (VarName.get_info v)) in - let hash = VarName.hash v in + let id = VarName.id v in let local_id = match StringMap.find_opt v_str !string_counter_map with | Some ids -> ( - match IntMap.find_opt hash ids with + match IntMap.find_opt id ids with | None -> - let id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in + let local_id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in string_counter_map := - StringMap.add v_str (IntMap.add hash id ids) !string_counter_map; - id + StringMap.add v_str (IntMap.add id local_id ids) !string_counter_map; + local_id | Some local_id -> local_id) | None -> string_counter_map := - StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; + StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; 0 in if v_str = "_" then Format.fprintf fmt "_" diff --git a/compiler/scalc/to_r.ml b/compiler/scalc/to_r.ml index 1a368b0f..eb1e82b9 100644 --- a/compiler/scalc/to_r.ml +++ b/compiler/scalc/to_r.ml @@ -220,11 +220,11 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let v_str = Mark.remove (VarName.get_info v) in - let hash = VarName.hash v in + let id = VarName.id v in let local_id = match StringMap.find_opt v_str !string_counter_map with | Some ids -> ( - match IntMap.find_opt hash ids with + match IntMap.find_opt id ids with | None -> let max_id = snd @@ -235,13 +235,13 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit = in string_counter_map := StringMap.add v_str - (IntMap.add hash (max_id + 1) ids) + (IntMap.add id (max_id + 1) ids) !string_counter_map; max_id + 1 | Some local_id -> local_id) | None -> string_counter_map := - StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; + StringMap.add v_str (IntMap.singleton id 0) !string_counter_map; 0 in if v_str = "_" then Format.fprintf fmt "dummy_var" diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index 100cfc10..07a0f6df 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -67,7 +67,7 @@ type 'm scope_decl = { } type 'm program = { - program_module_name : ModuleName.t option; + program_module_name : (ModuleName.t * module_intf_id) option; program_ctx : decl_ctx; program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index c47d1336..d44f6ef0 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -63,12 +63,13 @@ type 'm scope_decl = { } type 'm program = { - program_module_name : ModuleName.t option; + program_module_name : (ModuleName.t * module_intf_id) option; program_ctx : decl_ctx; program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; (* Using [nil] here ensure that program interfaces don't contain any - expressions. They won't contain any rules or topdefs, but will still have - the scope signatures needed to respect the call convention *) + expressions. They won't contain any rules or topdef implementations, but + will still have the scope signatures needed to respect the call + convention *) program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t; program_lang : Global.backend_lang; diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index de664409..fff9363d 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -42,9 +42,7 @@ module SVertex = struct | Topdef g1, Topdef g2 -> TopdefName.equal g1 g2 | (Scope _ | Topdef _), _ -> false - let hash = function - | Scope s -> ScopeName.hash s - | Topdef g -> TopdefName.hash g + let hash = function Scope s -> ScopeName.id s | Topdef g -> TopdefName.id g let format ppf = function | Scope s -> ScopeName.format ppf s @@ -206,7 +204,9 @@ module TVertex = struct type t = Struct of StructName.t | Enum of EnumName.t let hash x = - match x with Struct x -> StructName.hash x | Enum x -> EnumName.hash x + match x with + | Struct x -> StructName.id x + | Enum x -> Hashtbl.hash (`Enum (EnumName.id x)) let compare x y = match x, y with diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 03fa983a..630d6f1e 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -953,8 +953,9 @@ let translate_program let program_topdefs = TopdefName.Map.mapi (fun id -> function - | Some e, ty -> Expr.unbox (translate_expr ctx e), ty - | None, (_, pos) -> + | { D.topdef_expr = Some e; topdef_type = ty; topdef_visibility = _ } -> + Expr.unbox (translate_expr ctx e), ty + | { D.topdef_expr = None; topdef_type = _, pos; _ } -> Message.error ~pos "No definition found for %a" TopdefName.format id) desugared.program_root.module_topdefs in @@ -964,8 +965,7 @@ let translate_program desugared.D.program_root.module_scopes in { - Ast.program_module_name = - Option.map ModuleName.fresh desugared.D.program_module_name; + Ast.program_module_name = desugared.D.program_module_name; Ast.program_topdefs; Ast.program_scopes; Ast.program_ctx = ctx.decl_ctx; diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index dcb80d0e..56dd2065 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -668,8 +668,14 @@ type scope_info = { out_struct_fields : StructField.t ScopeVar.Map.t; } +type module_intf_id = { hash : Hash.t; is_external : bool } + +type module_tree_node = { deps : module_tree; intf_id : module_intf_id } + +and module_tree = module_tree_node ModuleName.Map.t (** In practice, this is a DAG: beware of repeated names *) -type module_tree = M of module_tree ModuleName.Map.t [@@caml.unboxed] + +type visibility = Private | Public type decl_ctx = { ctx_enums : enum_ctx; @@ -688,5 +694,5 @@ type 'e program = { decl_ctx : decl_ctx; code_items : 'e code_item_list; lang : Global.backend_lang; - module_name : ModuleName.t option; + module_name : (ModuleName.t * module_intf_id) option; } diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 5e73afa3..88837758 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -1155,29 +1155,57 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list reflect that. *) let evaluate_expr ctx lang e = evaluate_expr ctx lang (addcustom e) -let load_runtime_modules prg = - let load m = +let load_runtime_modules ~hashf prg = + let load (mname, intf_id) = + let hash = hashf intf_id.hash in + let expect_hash = + if intf_id.is_external then Hash.external_placeholder + else Hash.to_string hash + in let obj_file = Dynlink.adapt_filename - File.(Pos.get_file (Mark.get (ModuleName.get_info m)) -.- "cmo") + File.(Pos.get_file (Mark.get (ModuleName.get_info mname)) -.- "cmo") in - if not (Sys.file_exists obj_file) then + (if not (Sys.file_exists obj_file) then + Message.error + ~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") + ~pos:(Mark.get (ModuleName.get_info mname)) + "Compiled OCaml object %a@ not@ found.@ Make sure it has been \ + suitably compiled." + File.format obj_file + else + try Dynlink.loadfile obj_file + with Dynlink.Error dl_err -> + Message.error + "While loading compiled module from %a:@;<1 2>@[%a@]" + File.format obj_file Format.pp_print_text + (Dynlink.error_message dl_err)); + match Runtime.check_module (ModuleName.to_string mname) expect_hash with + | Ok () -> () + | Error bad_hash -> + Message.debug + "Module hash mismatch for %a:@ @[Expected: %a@,Found: %a@]" + ModuleName.format mname Hash.format hash + (fun ppf h -> + try Hash.format ppf (Hash.of_string h) + with Failure _ -> + if h = Hash.external_placeholder then + Format.fprintf ppf "@{%s@}" Hash.external_placeholder + else Format.fprintf ppf "@{@}") + bad_hash; Message.error - ~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") - ~pos:(Mark.get (ModuleName.get_info m)) - "Compiled OCaml object %a@ not@ found.@ Make sure it has been suitably \ - compiled." - File.format obj_file - else - try Dynlink.loadfile obj_file - with Dynlink.Error dl_err -> - Message.error "Error loading compiled module from %a:@;<1 2>@[%a@]" - File.format obj_file Format.pp_print_text - (Dynlink.error_message dl_err) + "Module %a@ needs@ recompiling:@ %a@ was@ likely@ compiled@ from@ an@ \ + older@ version@ or@ with@ incompatible@ flags." + ModuleName.format mname File.format obj_file + | exception Not_found -> + Message.error + "Module %a@ was loaded from file %a but did not register properly, \ + there is something wrong in its code." + ModuleName.format mname File.format obj_file in let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in if modules_list_topo <> [] then Message.debug "Loading shared modules... %a" (Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format) - modules_list_topo; + (List.map (fun (m, _) -> m) modules_list_topo); List.iter load modules_list_topo diff --git a/compiler/shared_ast/interpreter.mli b/compiler/shared_ast/interpreter.mli index b6a21894..f89c494e 100644 --- a/compiler/shared_ast/interpreter.mli +++ b/compiler/shared_ast/interpreter.mli @@ -62,6 +62,6 @@ val delcustom : (** Runtime check that the term contains no custom terms (raises [Invalid_argument] if that is the case *) -val load_runtime_modules : _ program -> unit +val load_runtime_modules : hashf:(Hash.t -> Hash.full) -> _ program -> unit (** Dynlink the runtime modules required by the given program, in order to make them callable by the interpreter. *) diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index d3a7f1f1..753a1c74 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -58,7 +58,7 @@ let empty_ctx = ctx_struct_fields = Ident.Map.empty; ctx_enum_constrs = Ident.Map.empty; ctx_scope_index = Ident.Map.empty; - ctx_modules = M ModuleName.Map.empty; + ctx_modules = ModuleName.Map.empty; } let get_scope_body { code_items; _ } scope = @@ -87,11 +87,11 @@ let to_expr p main_scope = res let modules_to_list (mt : module_tree) = - let rec aux acc (M mtree) = + let rec aux acc mtree = ModuleName.Map.fold - (fun mname sub acc -> - if List.exists (ModuleName.equal mname) acc then acc - else mname :: aux acc sub) + (fun mname mnode acc -> + if List.exists (fun (m, _) -> ModuleName.equal m mname) acc then acc + else (mname, mnode.intf_id) :: aux acc mnode.deps) mtree acc in List.rev (aux [] mt) diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 54a95047..071b7873 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -53,5 +53,6 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body -val modules_to_list : module_tree -> ModuleName.t list -(** Returns a list of used modules, in topological order *) +val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list +(** Returns a list of used modules, in topological order ; the boolean indicates + if the module is external *) diff --git a/compiler/shared_ast/type.ml b/compiler/shared_ast/type.ml index 29274ab9..a792e489 100644 --- a/compiler/shared_ast/type.ml +++ b/compiler/shared_ast/type.ml @@ -93,6 +93,22 @@ let rec compare ty1 ty2 = | TClosureEnv, _ -> -1 | _, TClosureEnv -> 1 +let rec hash ~strip ty = + let open Hash.Op in + match Mark.remove ty with + | TLit l -> !`TLit % !(l : typ_lit) + | TTuple tl -> List.fold_left (fun acc ty -> acc % hash ~strip ty) !`TTuple tl + | TStruct n -> !`TStruct % StructName.hash ~strip n + | TEnum n -> !`TEnum % EnumName.hash ~strip n + | TOption ty -> !`TOption % hash ~strip ty + | TArrow (tl, ty) -> + !`TArrow + % List.fold_left (fun acc ty -> acc % hash ~strip ty) (hash ~strip ty) tl + | TArray ty -> !`TArray % hash ~strip ty + | TDefault ty -> !`TDefault % hash ~strip ty + | TAny -> !`TAny + | TClosureEnv -> !`TClosureEnv + let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t let format = Print.typ_debug diff --git a/compiler/shared_ast/type.mli b/compiler/shared_ast/type.mli index ac2ec56d..5d026da7 100644 --- a/compiler/shared_ast/type.mli +++ b/compiler/shared_ast/type.mli @@ -14,6 +14,8 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils + type t = Definitions.typ val format : Format.formatter -> t -> unit @@ -23,6 +25,11 @@ module Map : Catala_utils.Map.S with type key = t val equal : t -> t -> bool val equal_list : t list -> t list -> bool val compare : t -> t -> int + +val hash : strip:Uid.Path.t -> t -> Hash.t +(** The [strip] argument strips the given leading path components in included + identifiers before hashing *) + val unifiable : t -> t -> bool val unifiable_list : t list -> t list -> bool diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index acaf4abb..88ca7433 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -31,6 +31,7 @@ module Any = let format fmt () = Format.fprintf fmt "any" let equal () () = true let compare () () = 0 + let hash () = Hash.raw `Any end) (struct let style = Ocolor_types.(Fg (C4 hi_magenta)) @@ -166,7 +167,7 @@ let rec format_typ format_typ ~colors fmt t1; Format.pp_print_as fmt 1 "⟩" | TAny v -> - if Global.options.debug then Format.fprintf fmt "" (Any.hash v) + if Global.options.debug then Format.fprintf fmt "" (Any.id v) else Format.pp_print_string fmt "" | TClosureEnv -> Format.fprintf fmt "closure_env" diff --git a/compiler/surface/ast.ml b/compiler/surface/ast.ml index 60f962ff..9a37fde1 100644 --- a/compiler/surface/ast.ml +++ b/compiler/surface/ast.ml @@ -318,7 +318,7 @@ and law_structure = | CodeBlock of code_block * source_repr * bool (* Metadata if true *) and interface = { - intf_modname : uident Mark.pos; + intf_modname : program_module; intf_code : code_block; (** Invariant: an interface shall only contain [*Decl] elements, or [Topdef] elements with [topdef_expr = None] *) @@ -330,8 +330,10 @@ and module_use = { mod_use_alias : uident Mark.pos; } +and program_module = { module_name : uident Mark.pos; module_external : bool } + and program = { - program_module_name : uident Mark.pos option; + program_module : program_module option; program_items : law_structure list; program_source_files : (string[@opaque]) list; program_used_modules : module_use list; diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index b1474823..672749d3 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -259,18 +259,21 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : List.fold_left (fun acc command -> let join_module_names name_opt = - match acc.Ast.program_module_name, name_opt with + match acc.Ast.program_module, name_opt with | opt, None | None, opt -> opt | Some id1, Some id2 -> Message.error - ~extra_pos:["", Mark.get id1; "", Mark.get id2] + ~extra_pos: + ["", Mark.get id1.module_name; "", Mark.get id2.module_name] "Multiple definitions of the module name" in match command with - | Ast.ModuleDef (id, _) -> + | Ast.ModuleDef (id, is_external) -> { acc with - Ast.program_module_name = join_module_names (Some id); + Ast.program_module = + join_module_names + (Some { module_name = id; module_external = is_external }); Ast.program_items = command :: acc.Ast.program_items; } | Ast.ModuleUse (mod_use_name, alias) -> @@ -288,22 +291,22 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : @@ fun lexbuf -> let includ_program = parse_source lexbuf in let () = - includ_program.Ast.program_module_name + includ_program.Ast.program_module |> Option.iter @@ fun id -> Message.error ~extra_pos: [ "File include", Mark.get inc_file; - "Module declaration", Mark.get id; + "Module declaration", Mark.get id.Ast.module_name; ] "A file that declares a module cannot be used through the raw \ '@{> Include@}'@ directive.@ You should use it as a \ module with@ '@{> Use @{%s@}@}'@ instead." - (Mark.remove id) + (Mark.remove id.Ast.module_name) in { - Ast.program_module_name = acc.program_module_name; + Ast.program_module = acc.program_module; Ast.program_source_files = List.rev_append includ_program.program_source_files acc.Ast.program_source_files; @@ -316,7 +319,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : } | Ast.LawHeading (heading, commands') -> let { - Ast.program_module_name; + Ast.program_module; Ast.program_items = commands'; Ast.program_source_files = new_sources; Ast.program_used_modules = new_used_modules; @@ -325,7 +328,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : expand_includes source_file commands' in { - Ast.program_module_name = join_module_names program_module_name; + Ast.program_module = join_module_names program_module; Ast.program_source_files = List.rev_append new_sources acc.Ast.program_source_files; Ast.program_items = @@ -336,7 +339,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : } | i -> { acc with Ast.program_items = i :: acc.Ast.program_items }) { - Ast.program_module_name = None; + Ast.program_module = None; Ast.program_source_files = []; Ast.program_items = []; Ast.program_used_modules = []; @@ -346,7 +349,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : in { Ast.program_lang = language; - Ast.program_module_name = rprg.Ast.program_module_name; + Ast.program_module = rprg.Ast.program_module; Ast.program_source_files = List.rev rprg.Ast.program_source_files; Ast.program_items = List.rev rprg.Ast.program_items; Ast.program_used_modules = List.rev rprg.Ast.program_used_modules; @@ -396,8 +399,8 @@ let with_sedlex_source source_file f = f lexbuf let check_modname program source_file = - match program.Ast.program_module_name, source_file with - | ( Some (mname, pos), + match program.Ast.program_module, source_file with + | ( Some { module_name = mname, pos; _ }, (Global.FileName file | Global.Contents (_, file) | Global.Stdin file) ) when not File.(equal mname Filename.(remove_extension (basename file))) -> Message.error ~pos @@ -413,10 +416,14 @@ let load_interface ?default_module_name source_file = let program = with_sedlex_source source_file parse_source in check_modname program source_file; let modname = - match program.Ast.program_module_name, default_module_name with + match program.Ast.program_module, default_module_name with | Some mname, _ -> mname | None, Some n -> - n, Pos.from_info (Global.input_src_file source_file) 0 0 0 0 + { + module_name = + n, Pos.from_info (Global.input_src_file source_file) 0 0 0 0; + module_external = false; + } | None, None -> Message.error "%a doesn't define a module name. It should contain a '@{> \ diff --git a/doc/devel/externals.md b/doc/devel/externals.md index 6d837c09..28c7cfbe 100644 --- a/doc/devel/externals.md +++ b/doc/devel/externals.md @@ -31,10 +31,10 @@ catala implementation and compile to OCaml (removing the `external` directive): ``` ```shell-session -$ clerk build _build/.../Prorata_external.ml +$ clerk build _build/.../prorata_external.ml ``` -(beware the `_build/`, and the capitalisation of the module name) +(beware the `_build/`, it is required here) ## Write the OCaml implementation @@ -44,9 +44,11 @@ capitalisation to match). Edit to replace the dummy implementation by your code. Refer to `runtimes/ocaml/runtime.mli` for what is available (especially the `Oper` module to manipulate the types). -Keep the `register_module` at the end as is, it's needed for the toplevel to use -the value (you would get `Failure("Could not resolve reference to Xxx")` during -evaluation). +Keep the `register_module` at the end, but replace the hash (which should be of +the form `"CM0|XXXXXXXX|XXXXXXXX|XXXXXXXX"`) by the string `"*external*"`. This +section is needed for the Catala interpreter to find the declared values --- the +error `Failure("Could not resolve reference to Xxx")` during evaluation is a +symptom that it is missing. ## Compile and test diff --git a/runtimes/ocaml/runtime.ml b/runtimes/ocaml/runtime.ml index 4d626efd..60d279c5 100644 --- a/runtimes/ocaml/runtime.ml +++ b/runtimes/ocaml/runtime.ml @@ -897,7 +897,9 @@ let register_module modname values hash = Hashtbl.add modules_table modname hash; List.iter (fun (id, v) -> Hashtbl.add values_table ([modname], id) v) values -let check_module m h = String.equal (Hashtbl.find modules_table m) h +let check_module m h = + let h1 = Hashtbl.find modules_table m in + if String.equal h h1 then Ok () else Error h1 let lookup_value qid = try Hashtbl.find values_table qid diff --git a/runtimes/ocaml/runtime.mli b/runtimes/ocaml/runtime.mli index 9b9124b5..07b8550d 100644 --- a/runtimes/ocaml/runtime.mli +++ b/runtimes/ocaml/runtime.mli @@ -446,8 +446,8 @@ val register_module : string -> (string * Obj.t) list -> hash -> unit expected to be a hash of the source file and the Catala version, and will in time be used to ensure that the module and the interface are in sync *) -val check_module : string -> hash -> bool -(** Returns [true] if it has been registered with the correct hash, [false] if +val check_module : string -> hash -> (unit, hash) result +(** Returns [Ok] if it has been registered with the correct hash, [Error h] if there is a hash mismatch. @raise Not_found if the module does not exist at all *) diff --git a/tests/modules/good/mod_def.catala_en b/tests/modules/good/mod_def.catala_en index 4b0b1e84..65cd7131 100644 --- a/tests/modules/good/mod_def.catala_en +++ b/tests/modules/good/mod_def.catala_en @@ -19,12 +19,20 @@ declaration scope S: declaration half content decimal depends on x content integer equals x / 2 + +declaration maybe content Enum1 + depends on x content Enum1 ``` ```catala scope S: definition sr equals $1,000 definition e1 equals Maybe + + +declaration maybe content Enum1 + depends on x content Enum1 + equals Maybe ``` diff --git a/tests/modules/good/output/mod_def.ml b/tests/modules/good/output/mod_def.ml index ebeabc56..d3ac9384 100644 --- a/tests/modules/good/output/mod_def.ml +++ b/tests/modules/good/output/mod_def.ml @@ -30,7 +30,7 @@ let s (s_in: S_in.t) : S.t = try (handle_default [|{filename="tests/modules/good/mod_def.catala_en"; - start_line=26; start_column=24; end_line=26; end_column=30; + start_line=29; start_column=24; end_line=29; end_column=30; law_headings=["Test modules + inclusions 1"]}|] ([|(fun (_: unit) -> handle_default [||] ([||]) (fun (_: unit) -> true) @@ -47,7 +47,7 @@ let s (s_in: S_in.t) : S.t = try (handle_default [|{filename="tests/modules/good/mod_def.catala_en"; - start_line=27; start_column=24; end_line=27; end_column=29; + start_line=30; start_column=24; end_line=30; end_column=29; law_headings=["Test modules + inclusions 1"]}|] ([|(fun (_: unit) -> handle_default [||] ([||]) (fun (_: unit) -> true) @@ -70,8 +70,12 @@ let half_ : integer -> decimal = law_headings=["Test modules + inclusions 1"]} x_ (integer_of_string "2") +let maybe_ : Enum1.t -> Enum1.t = + fun (_: Enum1.t) -> Enum1.Maybe () + let () = Runtime_ocaml.Runtime.register_module "Mod_def" [ "S", Obj.repr s; - "half", Obj.repr half_ ] - "todo-module-hash" + "half", Obj.repr half_; + "maybe", Obj.repr maybe_ ] + "CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX" diff --git a/tests/modules/good/prorata_external.ml b/tests/modules/good/prorata_external.ml index 528e1f16..97c8be44 100644 --- a/tests/modules/good/prorata_external.ml +++ b/tests/modules/good/prorata_external.ml @@ -37,4 +37,4 @@ let () = Runtime_ocaml.Runtime.register_module "Prorata_external" [ "prorata", Obj.repr prorata_; "prorata2", Obj.repr prorata2_ ] - "todo-module-hash" + "*external*" diff --git a/tests/name_resolution/good/let_in2.catala_en b/tests/name_resolution/good/let_in2.catala_en index 2253a9ff..826234bd 100644 --- a/tests/name_resolution/good/let_in2.catala_en +++ b/tests/name_resolution/good/let_in2.catala_en @@ -90,5 +90,5 @@ let s (s_in: S_in.t) : S.t = let () = Runtime_ocaml.Runtime.register_module "Let_in2" [ "S", Obj.repr s ] - "todo-module-hash" + "CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX" ```