Implement module hashes and checks (#625)

This commit is contained in:
Louis Gesbert 2024-05-28 12:36:40 +02:00 committed by GitHub
commit 9af7548bf0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 788 additions and 223 deletions

View File

@ -548,8 +548,11 @@ let[@ocamlformat "disable"] static_base_rules =
Nj.rule "out-test" Nj.rule "out-test"
~command: [ ~command: [
!catala_exe; !test_command; "--plugin-dir="; "-o -"; !catala_flags; !catala_exe; !test_command; "--plugin-dir="; "-o -"; !catala_flags;
!input; ">"; !output; "2>&1"; !input; "2>&1";
"||"; "true"; "|"; "sed";
"'s/\"CM0|[a-zA-Z0-9|]*\"/\"CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX\"/g'";
">"; !output;
"||"; "true"
] ]
~description: ~description:
["<catala>"; "test"; !test_id; ""; !input; "(" ^ !test_command ^ ")"]; ["<catala>"; "test"; !test_id; ""; !input; "(" ^ !test_command ^ ")"];

View File

@ -16,10 +16,32 @@
open Catala_utils open Catala_utils
let sanitize =
let re_endtest = Re.(compile @@ seq [bol; str "```"]) in
let re_modhash =
Re.(
compile
@@ seq
[
str "\"CM0|";
repn xdigit 8 (Some 8);
char '|';
repn xdigit 8 (Some 8);
char '|';
repn xdigit 8 (Some 8);
char '"';
])
in
fun str ->
str
|> Re.replace_string re_endtest ~by:"\\```"
|> Re.replace_string re_modhash ~by:"\"CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX\""
let run_catala_test test_flags catala_exe catala_opts file program args oc = let run_catala_test test_flags catala_exe catala_opts file program args oc =
let cmd_in_rd, cmd_in_wr = Unix.pipe () in let cmd_in_rd, cmd_in_wr = Unix.pipe ~cloexec:true () in
Unix.set_close_on_exec cmd_in_wr; let cmd_out_rd, cmd_out_wr = Unix.pipe ~cloexec:true () in
let command_oc = Unix.out_channel_of_descr cmd_in_wr in let command_oc = Unix.out_channel_of_descr cmd_in_wr in
let command_ic = Unix.in_channel_of_descr cmd_out_rd in
let catala_exe = let catala_exe =
(* If the exe name contains directories, make it absolute. Otherwise don't (* If the exe name contains directories, make it absolute. Otherwise don't
modify it so that it can be looked up in PATH. *) modify it so that it can be looked up in PATH. *)
@ -59,12 +81,21 @@ let run_catala_test test_flags catala_exe catala_opts file program args oc =
|> Seq.cons "CATALA_PLUGINS=" |> Seq.cons "CATALA_PLUGINS="
|> Array.of_seq |> Array.of_seq
in in
flush oc; let pid =
let ocfd = Unix.descr_of_out_channel oc in Unix.create_process_env catala_exe cmd env cmd_in_rd cmd_out_wr cmd_out_wr
let pid = Unix.create_process_env catala_exe cmd env cmd_in_rd ocfd ocfd in in
Unix.close cmd_in_rd; Unix.close cmd_in_rd;
Unix.close cmd_out_wr;
Seq.iter (output_string command_oc) program; Seq.iter (output_string command_oc) program;
close_out command_oc; close_out command_oc;
let out_lines =
Seq.of_dispenser (fun () -> In_channel.input_line command_ic)
in
Seq.iter
(fun line ->
output_string oc (sanitize line);
output_char oc '\n')
out_lines;
let return_code = let return_code =
match Unix.waitpid [] pid with match Unix.waitpid [] pid with
| _, Unix.WEXITED n -> n | _, Unix.WEXITED n -> n

View File

@ -0,0 +1,110 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2024 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
type t = int
let mix (h1 : t) (h2 : t) : t = Hashtbl.hash (h1, h2)
let raw = Hashtbl.hash
module Op = struct
let ( % ) = mix
let ( ! ) = raw
end
open Op
let option f = function None -> !`None | Some x -> !`Some % f x
let list hf l = List.fold_left (fun acc x -> acc % hf x) !`ListEmpty l
let map fold_fun kh vh map =
fold_fun (fun k v acc -> acc lxor (kh k % vh v)) map !`HashMapDelim
module Flags : sig
type nonrec t = private t
val pass :
(t -> 'a) ->
avoid_exceptions:bool ->
closure_conversion:bool ->
monomorphize_types:bool ->
'a
val of_t : int -> t
end = struct
type nonrec t = t
let pass k ~avoid_exceptions ~closure_conversion ~monomorphize_types =
(* Should not affect the call convention or actual interfaces: include,
optimize, check_invariants, typed *)
!(avoid_exceptions : bool)
% !(closure_conversion : bool)
% !(monomorphize_types : bool)
% (* The following may not affect the call convention, but we want it set in
an homogeneous way *)
!(Global.options.trace : bool)
% !(Global.options.max_prec_digits : int)
|> k
let of_t t = t
end
type full = { catala_version : t; flags_hash : Flags.t; contents : t }
let finalise t =
Flags.pass (fun flags_hash ->
{ catala_version = !(Version.v : string); flags_hash; contents = t })
let to_string full =
Printf.sprintf "CM0|%08x|%08x|%08x" full.catala_version
(full.flags_hash :> int)
full.contents
(* Putting color inside the hash makes them much easier to differentiate at a
glance *)
let format ppf full =
let open Ocolor_types in
let pcolor col f x =
Format.pp_open_stag ppf Ocolor_format.(Ocolor_style_tag (Fg (C24 col)));
f x;
Format.pp_close_stag ppf ()
in
let tag = pcolor { r24 = 172; g24 = 172; b24 = 172 } in
let auto i =
{
r24 = 128 + (i mod 128);
g24 = 128 + ((i lsr 10) mod 128);
b24 = 128 + ((i lsr 20) mod 128);
}
in
let phash h =
let col = auto h in
pcolor col (Format.fprintf ppf "%08x") h
in
tag (Format.pp_print_string ppf) "CM0|";
phash full.catala_version;
tag (Format.pp_print_string ppf) "|";
phash (full.flags_hash :> int);
tag (Format.pp_print_string ppf) "|";
phash full.contents
let of_string s =
try
Scanf.sscanf s "CM0|%08x|%08x|%08x"
(fun catala_version flags_hash contents ->
{ catala_version; flags_hash = Flags.of_t flags_hash; contents })
with Scanf.Scan_failure _ -> failwith "Hash.of_string"
let external_placeholder = "*external*"

View File

@ -0,0 +1,81 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2024 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** Hashes for the identification of modules.
In contrast with OCaml's basic `Hashtbl.hash`, they process the full depth
of terms. Any meaningful interface change in a module should only be in hash
collision with a 1/2^30 probability. *)
type t = private int
(** Native Hasthbl.hash hashes, value is truncated to 30 bits whatever the
architecture (positive 31-bit integers) *)
type full
(** A "full" hash includes the Catala version and compilation flags, alongside
the module interface *)
val raw : 'a -> t
(** [Hashtbl.hash]. Do not use on deep types (it has a bounded depth), use
specific hashing functions. *)
module Op : sig
val ( ! ) : 'a -> t
(** Shortcut to [raw]. Same warning: use with an explicit type annotation
[!(foo: string)] to ensure it's not called on types that are recursive or
include annotations.
Hint: we use [!`Foo] as a fancy way to generate constants for
discriminating constructions *)
val ( % ) : t -> t -> t
(** Safe combination of two hashes (non commutative or associative, etc.) *)
end
val option : ('a -> t) -> 'a option -> t
val list : ('a -> t) -> 'a list -> t
val map :
(('k -> 'v -> t -> t) -> 'map -> t -> t) ->
('k -> t) ->
('v -> t) ->
'map ->
t
(** [map fold_f key_hash_f value_hash_f map] computes the hash of a map. The
first argument is expected to be a [Foo.Map.fold] function. The result is
independent of the ordering of the map. *)
val finalise :
t ->
avoid_exceptions:bool ->
closure_conversion:bool ->
monomorphize_types:bool ->
full
(** Turns a raw interface hash into a full hash, ready for printing *)
val to_string : full -> string
val format : Format.formatter -> full -> unit
val of_string : string -> full
(** @raise Failure *)
val external_placeholder : string
(** It's inconvenient to need hash updates on external modules. This string is
uses as a hash instead for those cases.
NOTE: This is a temporary solution A future approach could be to have Catala
generate a module loader (with the proper hash), relieving the user
implementation from having to do the registration. *)

View File

@ -29,6 +29,7 @@ let fold f (x, _) = f x
let fold2 f (x, _) (y, _) = f x y let fold2 f (x, _) (y, _) = f x y
let compare cmp a b = fold2 cmp a b let compare cmp a b = fold2 cmp a b
let equal eq a b = fold2 eq a b let equal eq a b = fold2 eq a b
let hash f (x, _) = f x
class ['self] marked_map = class ['self] marked_map =
object (_self : 'self) object (_self : 'self)

View File

@ -41,6 +41,10 @@ val compare : ('a -> 'a -> int) -> ('a, 'm) ed -> ('a, 'm) ed -> int
val equal : ('a -> 'a -> bool) -> ('a, 'm) ed -> ('a, 'm) ed -> bool val equal : ('a -> 'a -> bool) -> ('a, 'm) ed -> ('a, 'm) ed -> bool
(** Tests equality of two marked values {b ignoring marks} *) (** Tests equality of two marked values {b ignoring marks} *)
val hash : ('a -> Hash.t) -> ('a, 'm) ed -> Hash.t
(** Computes the hash of the marked values using the given function
{b ignoring mark} *)
(** Visitors *) (** Visitors *)
class ['self] marked_map : object ('self) class ['self] marked_map : object ('self)

View File

@ -101,6 +101,7 @@ module Arg = struct
end end
let compare = Arg.compare let compare = Arg.compare
let hash t = Hash.raw t
module Set = Set.Make (Arg) module Set = Set.Make (Arg)
module Map = Map.Make (Arg) module Map = Map.Make (Arg)

View File

@ -23,6 +23,8 @@ module Map : Map.S with type key = string
val compare : string -> string -> int val compare : string -> string -> int
(** String comparison with natural ordering of numbers within strings *) (** String comparison with natural ordering of numbers within strings *)
val hash : string -> Hash.t
val to_ascii : string -> string val to_ascii : string -> string
(** Removes all non-ASCII diacritics from a string by converting them to their (** Removes all non-ASCII diacritics from a string by converting them to their
base letter in the Latin alphabet. *) base letter in the Latin alphabet. *)

View File

@ -21,6 +21,7 @@ module type Info = sig
val format : Format.formatter -> info -> unit val format : Format.formatter -> info -> unit
val equal : info -> info -> bool val equal : info -> info -> bool
val compare : info -> info -> int val compare : info -> info -> int
val hash : info -> Hash.t
end end
module type Id = sig module type Id = sig
@ -33,7 +34,8 @@ module type Id = sig
val equal : t -> t -> bool val equal : t -> t -> bool
val format : Format.formatter -> t -> unit val format : Format.formatter -> t -> unit
val to_string : t -> string val to_string : t -> string
val hash : t -> int val id : t -> int
val hash : t -> Hash.t
module Set : Set.S with type elt = t module Set : Set.S with type elt = t
module Map : Map.S with type key = t module Map : Map.S with type key = t
@ -68,8 +70,9 @@ module Make (X : Info) (S : Style) () : Id with type info = X.info = struct
{ id = !counter; info } { id = !counter; info }
let get_info (uid : t) : X.info = uid.info let get_info (uid : t) : X.info = uid.info
let hash (x : t) : int = x.id let id (x : t) : int = x.id
let to_string t = X.to_string t.info let to_string t = X.to_string t.info
let hash t = X.hash t.info
module Set = Set.Make (Ordering) module Set = Set.Make (Ordering)
module Map = Map.Make (Ordering) module Map = Map.Make (Ordering)
@ -84,6 +87,7 @@ module MarkedString = struct
let format fmt i = String.format fmt (to_string i) let format fmt i = String.format fmt (to_string i)
let equal = Mark.equal String.equal let equal = Mark.equal String.equal
let compare = Mark.compare String.compare let compare = Mark.compare String.compare
let hash = Mark.hash String.hash
end end
module Gen (S : Style) () = Make (MarkedString) (S) () module Gen (S : Style) () = Make (MarkedString) (S) ()
@ -109,6 +113,15 @@ module Path = struct
let to_string p = String.concat "." (List.map Module.to_string p) let to_string p = String.concat "." (List.map Module.to_string p)
let equal = List.equal Module.equal let equal = List.equal Module.equal
let compare = List.compare Module.compare let compare = List.compare Module.compare
let strip prefix p0 =
let rec aux prefix p =
match prefix, p with
| pfx1 :: pfx, p1 :: p -> if Module.equal pfx1 p1 then aux pfx p else p0
| [], p -> p
| _ -> p0
in
aux prefix p0
end end
module QualifiedMarkedString = struct module QualifiedMarkedString = struct
@ -125,12 +138,21 @@ module QualifiedMarkedString = struct
let compare (p1, i1) (p2, i2) = let compare (p1, i1) (p2, i2) =
match Path.compare p1 p2 with 0 -> MarkedString.compare i1 i2 | n -> n match Path.compare p1 p2 with 0 -> MarkedString.compare i1 i2 | n -> n
let hash (p, i) =
let open Hash.Op in
Hash.list Module.hash p % MarkedString.hash i
end end
module Gen_qualified (S : Style) () = struct module Gen_qualified (S : Style) () = struct
include Make (QualifiedMarkedString) (S) () include Make (QualifiedMarkedString) (S) ()
let fresh path t = fresh (path, t) let fresh path t = fresh (path, t)
let hash ~strip t =
let p, i = get_info t in
QualifiedMarkedString.hash (Path.strip strip p, i)
let path t = fst (get_info t) let path t = fst (get_info t)
let get_info t = snd (get_info t) let get_info t = snd (get_info t)
end end

View File

@ -28,6 +28,9 @@ module type Info = sig
val compare : info -> info -> int val compare : info -> info -> int
(** Comparison disregards position *) (** Comparison disregards position *)
val hash : info -> Hash.t
(** Hashing disregards position *)
end end
module MarkedString : Info with type info = string Mark.pos module MarkedString : Info with type info = string Mark.pos
@ -48,7 +51,15 @@ module type Id = sig
val equal : t -> t -> bool val equal : t -> t -> bool
val format : Format.formatter -> t -> unit val format : Format.formatter -> t -> unit
val to_string : t -> string val to_string : t -> string
val hash : t -> int
val id : t -> int
(** Returns the unique ID of the identifier *)
val hash : t -> Hash.t
(** While [id] returns a unique ID valable for a given Uid instance within a
given run of catala, this is a raw hash of the identifier string.
Therefore, it may collide within a given program, but remains meaninful
across separate compilations. *)
module Set : Set.S with type elt = t module Set : Set.S with type elt = t
module Map : Map.S with type key = t module Map : Map.S with type key = t
@ -79,6 +90,10 @@ module Path : sig
val format : Format.formatter -> t -> unit val format : Format.formatter -> t -> unit
val equal : t -> t -> bool val equal : t -> t -> bool
val compare : t -> t -> int val compare : t -> t -> int
val strip : t -> t -> t
(** [strip pfx p] removed [pfx] from the start of [p]. if [p] doesn't start
with [pfx], it is returned unchanged *)
end end
(** Same as [Gen] but also registers path information *) (** Same as [Gen] but also registers path information *)
@ -88,4 +103,6 @@ module Gen_qualified (_ : Style) () : sig
val fresh : Path.t -> MarkedString.info -> t val fresh : Path.t -> MarkedString.info -> t
val path : t -> Path.t val path : t -> Path.t
val get_info : t -> MarkedString.info val get_info : t -> MarkedString.info
val hash : strip:Path.t -> t -> Hash.t
(* [strip] strips that prefix from the start of the path before hashing *)
end end

View File

@ -71,12 +71,17 @@ module ScopeDef = struct
ScopeVar.format ppf (Mark.remove v); ScopeVar.format ppf (Mark.remove v);
format_kind ppf k format_kind ppf k
let hash_kind = function open Hash.Op
| Var None -> 0
| Var (Some st) -> StateName.hash st
| SubScopeInput { var_within_origin_scope = v; _ } -> ScopeVar.hash v
let hash (v, k) = Int.logxor (ScopeVar.hash (Mark.remove v)) (hash_kind k) let hash_kind ~strip = function
| Var v -> !`Var % Hash.option StateName.hash v
| SubScopeInput { name; var_within_origin_scope } ->
!`SubScopeInput
% ScopeName.hash ~strip name
% ScopeVar.hash var_within_origin_scope
let hash ~strip (v, k) =
Hash.Op.(ScopeVar.hash (Mark.remove v) % hash_kind ~strip k)
end end
include Base include Base
@ -231,6 +236,8 @@ type scope_def = {
type var_or_states = WholeVar | States of StateName.t list type var_or_states = WholeVar | States of StateName.t list
(* If fields are added, make sure to consider including them in the hash
computations below *)
type scope = { type scope = {
scope_vars : var_or_states ScopeVar.Map.t; scope_vars : var_or_states ScopeVar.Map.t;
scope_sub_scopes : ScopeName.t ScopeVar.Map.t; scope_sub_scopes : ScopeName.t ScopeVar.Map.t;
@ -239,21 +246,76 @@ type scope = {
scope_assertions : assertion AssertionName.Map.t; scope_assertions : assertion AssertionName.Map.t;
scope_options : catala_option Mark.pos list; scope_options : catala_option Mark.pos list;
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;
scope_visibility : visibility;
}
type topdef = {
topdef_expr : expr option;
topdef_type : typ;
topdef_visibility : visibility;
} }
type modul = { type modul = {
module_scopes : scope ScopeName.Map.t; module_scopes : scope ScopeName.Map.t;
module_topdefs : (expr option * typ) TopdefName.Map.t; module_topdefs : topdef TopdefName.Map.t;
} }
type program = { type program = {
program_module_name : Ident.t Mark.pos option; program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx; program_ctx : decl_ctx;
program_modules : modul ModuleName.Map.t; program_modules : modul ModuleName.Map.t;
program_root : modul; program_root : modul;
program_lang : Global.backend_lang; program_lang : Global.backend_lang;
} }
module Hash = struct
open Hash.Op
let var_or_state = function
| WholeVar -> !`WholeVar
| States s -> !`States % Hash.list StateName.hash s
let io x =
!(Mark.remove x.io_input : Runtime.io_input)
% !(Mark.remove x.io_output : bool)
let scope_decl ~strip d =
(* scope_def_rules is ignored (not part of the interface) *)
Type.hash ~strip d.scope_def_typ
% Hash.option
(fun (lst, _) ->
List.fold_left
(fun acc (name, ty) ->
acc % Uid.MarkedString.hash name % Type.hash ~strip ty)
!`SDparams lst)
d.scope_def_parameters
% !(d.scope_def_is_condition : bool)
% io d.scope_def_io
let scope ~strip s =
Hash.map ScopeVar.Map.fold ScopeVar.hash var_or_state s.scope_vars
% Hash.map ScopeVar.Map.fold ScopeVar.hash (ScopeName.hash ~strip)
s.scope_sub_scopes
% ScopeName.hash ~strip s.scope_uid
% Hash.map ScopeDef.Map.fold (ScopeDef.hash ~strip) (scope_decl ~strip)
s.scope_defs
(* assertions, options, etc. are not expected to be part of interfaces *)
let modul ?(strip = []) m =
Hash.map ScopeName.Map.fold (ScopeName.hash ~strip) (scope ~strip)
(ScopeName.Map.filter
(fun _ s -> s.scope_visibility = Public)
m.module_scopes)
% Hash.map TopdefName.Map.fold (TopdefName.hash ~strip)
(fun td -> Type.hash ~strip td.topdef_type)
(TopdefName.Map.filter
(fun _ td -> td.topdef_visibility = Public)
m.module_topdefs)
let module_binding modname m =
ModuleName.hash modname % modul ~strip:[modname] m
end
let rec locations_used e : LocationSet.t = let rec locations_used e : LocationSet.t =
match e with match e with
| ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m) | ELocation l, m -> LocationSet.singleton (l, Expr.mark_pos m)
@ -311,5 +373,5 @@ let fold_exprs ~(f : 'a -> expr -> 'a) ~(init : 'a) (p : program) : 'a =
p.program_root.module_scopes init p.program_root.module_scopes init
in in
TopdefName.Map.fold TopdefName.Map.fold
(fun _ (e, _) acc -> Option.fold ~none:acc ~some:(f acc) e) (fun _ tdef acc -> Option.fold ~none:acc ~some:(f acc) tdef.topdef_expr)
p.program_root.module_topdefs acc p.program_root.module_topdefs acc

View File

@ -32,7 +32,7 @@ module ScopeDef : sig
val equal_kind : kind -> kind -> bool val equal_kind : kind -> kind -> bool
val compare_kind : kind -> kind -> int val compare_kind : kind -> kind -> int
val format_kind : Format.formatter -> kind -> unit val format_kind : Format.formatter -> kind -> unit
val hash_kind : kind -> int val hash_kind : strip:Uid.Path.t -> kind -> Hash.t
type t = ScopeVar.t Mark.pos * kind type t = ScopeVar.t Mark.pos * kind
@ -40,7 +40,7 @@ module ScopeDef : sig
val compare : t -> t -> int val compare : t -> t -> int
val get_position : t -> Pos.t val get_position : t -> Pos.t
val format : Format.formatter -> t -> unit val format : Format.formatter -> t -> unit
val hash : t -> int val hash : strip:Uid.Path.t -> t -> Hash.t
module Map : Map.S with type key = t module Map : Map.S with type key = t
module Set : Set.S with type elt = t module Set : Set.S with type elt = t
@ -123,16 +123,23 @@ type scope = {
(** empty outside of the root module *) (** empty outside of the root module *)
scope_options : catala_option Mark.pos list; scope_options : catala_option Mark.pos list;
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;
scope_visibility : visibility;
}
type topdef = {
topdef_expr : expr option; (** Always [None] outside of the root module *)
topdef_type : typ;
topdef_visibility : visibility;
(** Necessarily [Public] outside of the root module *)
} }
type modul = { type modul = {
module_scopes : scope ScopeName.Map.t; module_scopes : scope ScopeName.Map.t;
module_topdefs : (expr option * typ) TopdefName.Map.t; module_topdefs : topdef TopdefName.Map.t;
(** the expr is [None] outside of the root module *)
} }
type program = { type program = {
program_module_name : Ident.t Mark.pos option; program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx; program_ctx : decl_ctx;
program_modules : modul ModuleName.Map.t; program_modules : modul ModuleName.Map.t;
(** Contains all submodules of the program, in a flattened structure *) (** Contains all submodules of the program, in a flattened structure *)
@ -140,6 +147,18 @@ type program = {
program_lang : Global.backend_lang; program_lang : Global.backend_lang;
} }
(** {1 Interface hash computations} *)
(** These hashes are computed on interfaces: only signatures are considered. *)
module Hash : sig
(** The [strip] argument below strips as many leading path components before
hashing *)
val scope : strip:Uid.Path.t -> scope -> Hash.t
val modul : ?strip:Uid.Path.t -> modul -> Hash.t
val module_binding : ModuleName.t -> modul -> Hash.t
end
(** {1 Helpers} *) (** {1 Helpers} *)
val locations_used : expr -> LocationSet.t val locations_used : expr -> LocationSet.t

View File

@ -39,9 +39,9 @@ module Vertex = struct
let hash x = let hash x =
match x with match x with
| Var (x, None) -> ScopeVar.hash x | Var (x, None) -> ScopeVar.id x
| Var (x, Some sx) -> Int.logxor (ScopeVar.hash x) (StateName.hash sx) | Var (x, Some sx) -> Hashtbl.hash (ScopeVar.id x, StateName.id sx)
| Assertion a -> Ast.AssertionName.hash a | Assertion a -> Hashtbl.hash (`Assert (Ast.AssertionName.id a))
let compare x y = let compare x y =
match x, y with match x, y with
@ -252,7 +252,7 @@ module ExceptionVertex = struct
let hash (x : t) : int = let hash (x : t) : int =
RuleName.Map.fold RuleName.Map.fold
(fun r _ acc -> Int.logxor (RuleName.hash r) acc) (fun r _ acc -> Hashtbl.hash (RuleName.id r, acc))
x.rules 0 x.rules 0
let equal x y = compare x y = 0 let equal x y = compare x y = 0

View File

@ -98,10 +98,14 @@ let program prg =
in in
let module_topdefs = let module_topdefs =
TopdefName.Map.map TopdefName.Map.map
(function (fun def ->
| Some e, ty -> {
Some (Expr.unbox (expr prg.program_ctx env (Expr.box e))), ty def with
| None, ty -> None, ty) topdef_expr =
Option.map
(fun e -> Expr.unbox (expr prg.program_ctx env (Expr.box e)))
def.topdef_expr;
})
prg.program_root.module_topdefs prg.program_root.module_topdefs
in in
let module_scopes = let module_scopes =

View File

@ -313,7 +313,7 @@ let rec translate_expr
in in
let e2 = rec_helper ~local_vars e2 in let e2 = rec_helper ~local_vars e2 in
Expr.make_abs [| binding_var |] e2 [tau] pos_op) Expr.make_abs [| binding_var |] e2 [tau] pos_op)
(EnumName.Map.find enum_uid ctxt.enums) (fst (EnumName.Map.find enum_uid ctxt.enums))
in in
Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark
| Binop ((((S.And | S.Or | S.Xor), _) as op), e1, e2) -> | Binop ((((S.And | S.Or | S.Xor), _) as op), e1, e2) ->
@ -613,7 +613,7 @@ let rec translate_expr
StructField.Map.add f_uid f_e s_fields) StructField.Map.add f_uid f_e s_fields)
StructField.Map.empty fields StructField.Map.empty fields
in in
let expected_s_fields = StructName.Map.find s_uid ctxt.structs in let expected_s_fields, _ = StructName.Map.find s_uid ctxt.structs in
if if
StructField.Map.exists StructField.Map.exists
(fun expected_f _ -> not (StructField.Map.mem expected_f s_fields)) (fun expected_f _ -> not (StructField.Map.mem expected_f s_fields))
@ -719,7 +719,7 @@ let rec translate_expr
Expr.make_abs [| nop_var |] Expr.make_abs [| nop_var |]
(Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark) (Expr.elit (LBool (EnumConstructor.compare c_uid c_uid' = 0)) emark)
[tau] pos) [tau] pos)
(EnumName.Map.find enum_uid ctxt.enums) (fst (EnumName.Map.find enum_uid ctxt.enums))
in in
Expr.ematch ~e:(rec_helper e1) ~name:enum_uid ~cases emark Expr.ematch ~e:(rec_helper e1) ~name:enum_uid ~cases emark
| ArrayLit es -> Expr.earray (List.map rec_helper es) emark | ArrayLit es -> Expr.earray (List.map rec_helper es) emark
@ -970,7 +970,7 @@ and disambiguate_match_and_build_expression
Expr.eabs e_binder Expr.eabs e_binder
[ [
EnumConstructor.Map.find c_uid EnumConstructor.Map.find c_uid
(EnumName.Map.find e_uid ctxt.Name_resolution.enums); (fst (EnumName.Map.find e_uid ctxt.Name_resolution.enums));
] ]
(Mark.get case_body) (Mark.get case_body)
in in
@ -1037,6 +1037,7 @@ and disambiguate_match_and_build_expression
if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err (); if curr_index < nb_cases - 1 then raise_wildcard_not_last_case_err ();
let missing_constructors = let missing_constructors =
EnumName.Map.find e_uid ctxt.Name_resolution.enums EnumName.Map.find e_uid ctxt.Name_resolution.enums
|> fst
|> EnumConstructor.Map.filter_map (fun c_uid _ -> |> EnumConstructor.Map.filter_map (fun c_uid _ ->
match EnumConstructor.Map.find_opt c_uid cases_d with match EnumConstructor.Map.find_opt c_uid cases_d with
| Some _ -> None | Some _ -> None
@ -1451,6 +1452,7 @@ let process_scope_use
let process_topdef let process_topdef
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(prgm : Ast.program) (prgm : Ast.program)
(is_public : bool)
(def : S.top_def) : Ast.program = (def : S.top_def) : Ast.program =
let id = let id =
Ident.Map.find Ident.Map.find
@ -1493,12 +1495,14 @@ let process_topdef
in in
Some (Expr.unbox_closed e) Some (Expr.unbox_closed e)
in in
let topdef_visibility = if is_public then Public else Private in
let module_topdefs = let module_topdefs =
TopdefName.Map.update id TopdefName.Map.update id
(fun def0 -> (fun def0 ->
match def0, expr_opt with match def0, expr_opt with
| None, eopt -> Some (eopt, typ) | None, eopt ->
| Some (eopt0, ty0), eopt -> ( Some { Ast.topdef_expr = eopt; topdef_visibility; topdef_type = typ }
| Some def0, eopt -> (
let err msg = let err msg =
Message.error Message.error
~extra_pos: ~extra_pos:
@ -1508,13 +1512,16 @@ let process_topdef
] ]
(msg ^^ " for %a") TopdefName.format id (msg ^^ " for %a") TopdefName.format id
in in
if not (Type.equal ty0 typ) then err "Conflicting type definitions" if not (Type.equal def0.Ast.topdef_type typ) then
err "Conflicting type definitions"
else else
match eopt0, eopt with match def0.Ast.topdef_expr, eopt with
| None, None -> err "Multiple declarations" | None, None -> err "Multiple declarations"
| Some _, Some _ -> err "Multiple definitions" | Some _, Some _ -> err "Multiple definitions"
| Some e, None -> Some (Some e, typ) | (Some _ as topdef_expr), None ->
| None, Some e -> Some (Some e, ty0))) Some { Ast.topdef_expr; topdef_visibility; topdef_type = typ }
| None, (Some _ as topdef_expr) ->
Some { def0 with Ast.topdef_expr }))
prgm.Ast.program_root.module_topdefs prgm.Ast.program_root.module_topdefs
in in
{ prgm with program_root = { prgm.program_root with module_topdefs } } { prgm with program_root = { prgm.program_root with module_topdefs } }
@ -1680,6 +1687,7 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
scope_meta_assertions = []; scope_meta_assertions = [];
scope_options = []; scope_options = [];
scope_uid = s_uid; scope_uid = s_uid;
scope_visibility = s_context.Name_resolution.scope_visibility;
} }
in in
let get_scopes mctx = let get_scopes mctx =
@ -1692,22 +1700,32 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
mctx.Name_resolution.typedefs ScopeName.Map.empty mctx.Name_resolution.typedefs ScopeName.Map.empty
in in
let program_modules = let program_modules =
ModuleName.Map.map ModuleName.Map.mapi
(fun mctx -> (fun mname mctx ->
let m =
{ {
Ast.module_scopes = get_scopes mctx; Ast.module_scopes = get_scopes mctx;
Ast.module_topdefs = Ast.module_topdefs =
Ident.Map.fold Ident.Map.fold
(fun _ name acc -> (fun _ name acc ->
let topdef_type, topdef_visibility =
TopdefName.Map.find name ctxt.Name_resolution.topdefs
in
TopdefName.Map.add name TopdefName.Map.add name
( None, { Ast.topdef_expr = None; topdef_visibility; topdef_type }
TopdefName.Map.find name ctxt.Name_resolution.topdef_types
)
acc) acc)
mctx.topdefs TopdefName.Map.empty; mctx.topdefs TopdefName.Map.empty;
}) }
in
m, Ast.Hash.module_binding mname m)
ctxt.modules ctxt.modules
in in
let program_root =
{
Ast.module_scopes = get_scopes ctxt.Name_resolution.local;
Ast.module_topdefs = TopdefName.Map.empty;
}
in
let program_ctx = let program_ctx =
let open Name_resolution in let open Name_resolution in
let ctx_scopes mctx acc = let ctx_scopes mctx acc =
@ -1720,23 +1738,30 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
in in
let ctx_modules = let ctx_modules =
let rec aux mctx = let rec aux mctx =
let subs =
Ident.Map.fold Ident.Map.fold
(fun _ m (M acc) -> (fun _ m acc ->
let sub = aux (ModuleName.Map.find m ctxt.modules) in let mctx = ModuleName.Map.find m ctxt.Name_resolution.modules in
M (ModuleName.Map.add m sub acc)) let deps = aux mctx in
mctx.used_modules (M ModuleName.Map.empty) let hash = snd (ModuleName.Map.find m program_modules) in
ModuleName.Map.add m
{ deps; intf_id = { hash; is_external = mctx.is_external } }
acc)
mctx.used_modules ModuleName.Map.empty
in
subs
in in
aux ctxt.local aux ctxt.local
in in
{ {
ctx_structs = ctxt.structs; ctx_structs = StructName.Map.map fst ctxt.structs;
ctx_enums = ctxt.enums; ctx_enums = EnumName.Map.map fst ctxt.enums;
ctx_scopes = ctx_scopes =
ModuleName.Map.fold ModuleName.Map.fold
(fun _ -> ctx_scopes) (fun _ -> ctx_scopes)
ctxt.modules ctxt.modules
(ctx_scopes ctxt.local ScopeName.Map.empty); (ctx_scopes ctxt.local ScopeName.Map.empty);
ctx_topdefs = ctxt.topdef_types; ctx_topdefs = TopdefName.Map.map fst ctxt.topdefs;
ctx_struct_fields = ctxt.local.field_idmap; ctx_struct_fields = ctxt.local.field_idmap;
ctx_enum_constrs = ctxt.local.constructor_idmap; ctx_enum_constrs = ctxt.local.constructor_idmap;
ctx_scope_index = ctx_scope_index =
@ -1748,25 +1773,29 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
ctx_modules; ctx_modules;
} }
in in
let program_module_name =
surface.Surface.Ast.program_module
|> Option.map
@@ fun { Surface.Ast.module_name; module_external } ->
let mname = ModuleName.fresh module_name in
let hash_placeholder = Hash.raw 0 in
mname, { hash = hash_placeholder; is_external = module_external }
in
let desugared = let desugared =
{ {
Ast.program_lang = surface.program_lang; Ast.program_lang = surface.program_lang;
Ast.program_module_name = surface.Surface.Ast.program_module_name; Ast.program_module_name;
Ast.program_modules; Ast.program_modules = ModuleName.Map.map fst program_modules;
Ast.program_ctx; Ast.program_ctx;
Ast.program_root = Ast.program_root;
{
Ast.module_scopes = get_scopes ctxt.Name_resolution.local;
Ast.module_topdefs = TopdefName.Map.empty;
};
} }
in in
let process_code_block ctxt prgm block = let process_code_block ctxt prgm is_meta block =
List.fold_left List.fold_left
(fun prgm item -> (fun prgm item ->
match Mark.remove item with match Mark.remove item with
| S.ScopeUse use -> process_scope_use ctxt prgm use | S.ScopeUse use -> process_scope_use ctxt prgm use
| S.Topdef def -> process_topdef ctxt prgm def | S.Topdef def -> process_topdef ctxt prgm is_meta def
| S.ScopeDecl _ | S.StructDecl _ | S.EnumDecl _ -> prgm) | S.ScopeDecl _ | S.StructDecl _ | S.EnumDecl _ -> prgm)
prgm block prgm block
in in
@ -1777,7 +1806,22 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
List.fold_left List.fold_left
(fun prgm child -> process_structure prgm child) (fun prgm child -> process_structure prgm child)
prgm children prgm children
| S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block | S.CodeBlock (block, _, is_meta) ->
process_code_block ctxt prgm is_meta block
| S.ModuleDef _ | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm | S.ModuleDef _ | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm
in in
let desugared =
List.fold_left process_structure desugared surface.S.program_items List.fold_left process_structure desugared surface.S.program_items
in
{
desugared with
Ast.program_module_name =
(desugared.Ast.program_module_name
|> Option.map
@@ fun (mname, intf_id) ->
( mname,
{
intf_id with
hash = Ast.Hash.module_binding mname desugared.Ast.program_root;
} ));
}

View File

@ -39,6 +39,7 @@ type scope_context = {
scope_out_struct : StructName.t; scope_out_struct : StructName.t;
sub_scopes : ScopeName.Set.t; sub_scopes : ScopeName.Set.t;
(** Other scopes referred to by this scope. Used for dependency analysis *) (** Other scopes referred to by this scope. Used for dependency analysis *)
scope_visibility : visibility;
} }
(** Inside a scope, we distinguish between the variables and the subscopes. *) (** Inside a scope, we distinguish between the variables and the subscopes. *)
@ -77,15 +78,17 @@ type module_context = {
between different enums *) between different enums *)
topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *)
used_modules : ModuleName.t Ident.Map.t; used_modules : ModuleName.t Ident.Map.t;
is_external : bool;
} }
(** Context for name resolution, valid within a given module *) (** Context for name resolution, valid within a given module *)
type context = { type context = {
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdef_types : typ TopdefName.Map.t; topdefs : (typ * visibility) TopdefName.Map.t;
structs : struct_context StructName.Map.t; structs : (struct_context * visibility) StructName.Map.t;
(** For each struct, its context *) (** For each struct, its context *)
enums : enum_context EnumName.Map.t; (** For each enum, its context *) enums : (enum_context * visibility) EnumName.Map.t;
(** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t; var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
modules : module_context ModuleName.Map.t; modules : module_context ModuleName.Map.t;
@ -426,8 +429,10 @@ let process_data_decl
} }
(** Process a struct declaration *) (** Process a struct declaration *)
let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : let process_struct_decl
context = ?(visibility = Public)
(ctxt : context)
(sdecl : Surface.Ast.struct_decl) : context =
let s_uid = get_struct ctxt sdecl.struct_decl_name in let s_uid = get_struct ctxt sdecl.struct_decl_name in
if sdecl.struct_decl_fields = [] then if sdecl.struct_decl_fields = [] then
Message.error Message.error
@ -454,25 +459,28 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) :
let ctxt = { ctxt with local } in let ctxt = { ctxt with local } in
let structs = let structs =
StructName.Map.update s_uid StructName.Map.update s_uid
(fun fields -> (function
match fields with
| None -> | None ->
Some Some
( StructField.Map.singleton f_uid ( StructField.Map.singleton f_uid
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ),
| Some fields -> visibility )
| Some (fields, _) ->
Some Some
( StructField.Map.add f_uid ( StructField.Map.add f_uid
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)
fields)) fields,
visibility ))
ctxt.structs ctxt.structs
in in
{ ctxt with structs }) { ctxt with structs })
ctxt sdecl.struct_decl_fields ctxt sdecl.struct_decl_fields
(** Process an enum declaration *) (** Process an enum declaration *)
let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context let process_enum_decl
= ?(visibility = Public)
(ctxt : context)
(edecl : Surface.Ast.enum_decl) : context =
let e_uid = get_enum ctxt edecl.enum_decl_name in let e_uid = get_enum ctxt edecl.enum_decl_name in
if List.length edecl.enum_decl_cases = 0 then if List.length edecl.enum_decl_cases = 0 then
Message.error Message.error
@ -506,23 +514,24 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context
| Some typ -> process_type ctxt typ | Some typ -> process_type ctxt typ
in in
match cases with match cases with
| None -> Some (EnumConstructor.Map.singleton c_uid typ) | None -> Some (EnumConstructor.Map.singleton c_uid typ, visibility)
| Some fields -> Some (EnumConstructor.Map.add c_uid typ fields)) | Some (fields, _) ->
Some (EnumConstructor.Map.add c_uid typ fields, visibility))
ctxt.enums ctxt.enums
in in
{ ctxt with enums }) { ctxt with enums })
ctxt edecl.enum_decl_cases ctxt edecl.enum_decl_cases
let process_topdef ctxt def = let process_topdef ?(visibility = Public) ctxt def =
let uid = let uid =
Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.local.topdefs Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.local.topdefs
in in
{ {
ctxt with ctxt with
topdef_types = topdefs =
TopdefName.Map.add uid TopdefName.Map.add uid
(process_type ctxt def.Surface.Ast.topdef_type) (process_type ctxt def.Surface.Ast.topdef_type, visibility)
ctxt.topdef_types; ctxt.topdefs;
} }
(** Process an item declaration *) (** Process an item declaration *)
@ -536,8 +545,10 @@ let process_item_decl
process_subscope_decl scope ctxt sub_decl process_subscope_decl scope ctxt sub_decl
(** Process a scope declaration *) (** Process a scope declaration *)
let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : let process_scope_decl
context = ?(visibility = Public)
(ctxt : context)
(decl : Surface.Ast.scope_decl) : context =
let scope_uid = get_scope ctxt decl.scope_decl_name in let scope_uid = get_scope ctxt decl.scope_decl_name in
let ctxt = let ctxt =
List.fold_left List.fold_left
@ -588,11 +599,12 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) :
structs = structs =
StructName.Map.add StructName.Map.add
(get_struct ctxt decl.scope_decl_name) (get_struct ctxt decl.scope_decl_name)
StructField.Map.empty ctxt.structs; (StructField.Map.empty, visibility)
ctxt.structs;
} }
else else
let ctxt = let ctxt =
process_struct_decl ctxt process_struct_decl ~visibility ctxt
{ {
struct_decl_name = decl.scope_decl_name; struct_decl_name = decl.scope_decl_name;
struct_decl_fields = output_fields; struct_decl_fields = output_fields;
@ -634,8 +646,10 @@ let typedef_info = function
| TScope (s, _) -> ScopeName.get_info s | TScope (s, _) -> ScopeName.get_info s
(** Process the names of all declaration items *) (** Process the names of all declaration items *)
let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : let process_name_item
context = ?(visibility = Public)
(ctxt : context)
(item : Surface.Ast.code_item Mark.pos) : context =
let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg = let raise_already_defined_error (use : Uid.MarkedString.info) name pos msg =
Message.error Message.error
~fmt_pos: ~fmt_pos:
@ -676,6 +690,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
scope_in_struct = in_struct_name; scope_in_struct = in_struct_name;
scope_out_struct = out_struct_name; scope_out_struct = out_struct_name;
sub_scopes = ScopeName.Set.empty; sub_scopes = ScopeName.Set.empty;
scope_visibility = visibility;
} }
ctxt.scopes ctxt.scopes
in in
@ -720,14 +735,16 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
{ ctxt with local = { ctxt.local with topdefs } } { ctxt with local = { ctxt.local with topdefs } }
(** Process a code item that is a declaration *) (** Process a code item that is a declaration *)
let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : let process_decl_item
context = ?visibility
(ctxt : context)
(item : Surface.Ast.code_item Mark.pos) : context =
match Mark.remove item with match Mark.remove item with
| ScopeDecl decl -> process_scope_decl ctxt decl | ScopeDecl decl -> process_scope_decl ?visibility ctxt decl
| StructDecl sdecl -> process_struct_decl ctxt sdecl | StructDecl sdecl -> process_struct_decl ?visibility ctxt sdecl
| EnumDecl edecl -> process_enum_decl ctxt edecl | EnumDecl edecl -> process_enum_decl ?visibility ctxt edecl
| ScopeUse _ -> ctxt | ScopeUse _ -> ctxt
| Topdef def -> process_topdef ctxt def | Topdef def -> process_topdef ?visibility ctxt def
(** Process a code block *) (** Process a code block *)
let process_code_block let process_code_block
@ -738,7 +755,11 @@ let process_code_block
(** Process a law structure, only considering the code blocks *) (** Process a law structure, only considering the code blocks *)
let rec process_law_structure let rec process_law_structure
(process_item : context -> Surface.Ast.code_item Mark.pos -> context) (process_item :
?visibility:visibility ->
context ->
Surface.Ast.code_item Mark.pos ->
context)
(ctxt : context) (ctxt : context)
(s : Surface.Ast.law_structure) : context = (s : Surface.Ast.law_structure) : context =
match s with match s with
@ -746,10 +767,14 @@ let rec process_law_structure
List.fold_left List.fold_left
(fun ctxt child -> process_law_structure process_item ctxt child) (fun ctxt child -> process_law_structure process_item ctxt child)
ctxt children ctxt children
| Surface.Ast.CodeBlock (block, _, _) -> | Surface.Ast.CodeBlock (block, _, is_meta) ->
process_code_block process_item ctxt block process_code_block
(process_item ~visibility:(if is_meta then Public else Private))
ctxt block
| Surface.Ast.ModuleDef (_, is_external) ->
{ ctxt with local = { ctxt.local with is_external } }
| Surface.Ast.LawInclude _ | Surface.Ast.LawText _ -> ctxt | Surface.Ast.LawInclude _ | Surface.Ast.LawText _ -> ctxt
| Surface.Ast.ModuleDef _ | Surface.Ast.ModuleUse _ -> ctxt | Surface.Ast.ModuleUse _ -> ctxt
(** {1 Scope uses pass} *) (** {1 Scope uses pass} *)
@ -957,12 +982,13 @@ let empty_module_ctxt =
constructor_idmap = Ident.Map.empty; constructor_idmap = Ident.Map.empty;
topdefs = Ident.Map.empty; topdefs = Ident.Map.empty;
used_modules = Ident.Map.empty; used_modules = Ident.Map.empty;
is_external = false;
} }
let empty_ctxt = let empty_ctxt =
{ {
scopes = ScopeName.Map.empty; scopes = ScopeName.Map.empty;
topdef_types = TopdefName.Map.empty; topdefs = TopdefName.Map.empty;
var_typs = ScopeVar.Map.empty; var_typs = ScopeVar.Map.empty;
structs = StructName.Map.empty; structs = StructName.Map.empty;
enums = EnumName.Map.empty; enums = EnumName.Map.empty;
@ -985,7 +1011,13 @@ let form_context (surface, mod_uses) surface_modules : context =
let ctxt = let ctxt =
{ {
ctxt with ctxt with
local = { ctxt.local with used_modules = mod_uses; path = [m] }; local =
{
ctxt.local with
used_modules = mod_uses;
path = [m];
is_external = intf.Surface.Ast.intf_modname.module_external;
};
} }
in in
let ctxt = let ctxt =
@ -1017,7 +1049,7 @@ let form_context (surface, mod_uses) surface_modules : context =
in in
let ctxt = let ctxt =
List.fold_left List.fold_left
(process_law_structure process_use_item) (process_law_structure (fun ?visibility:_ -> process_use_item))
ctxt surface.Surface.Ast.program_items ctxt surface.Surface.Ast.program_items
in in
(* Gather struct fields and enum constrs from direct modules: this helps with (* Gather struct fields and enum constrs from direct modules: this helps with

View File

@ -39,6 +39,7 @@ type scope_context = {
scope_out_struct : StructName.t; scope_out_struct : StructName.t;
sub_scopes : ScopeName.Set.t; sub_scopes : ScopeName.Set.t;
(** Other scopes referred to by this scope. Used for dependency analysis *) (** Other scopes referred to by this scope. Used for dependency analysis *)
scope_visibility : visibility;
} }
(** Inside a scope, we distinguish between the variables and the subscopes. *) (** Inside a scope, we distinguish between the variables and the subscopes. *)
@ -82,16 +83,18 @@ type module_context = {
topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *)
used_modules : ModuleName.t Ident.Map.t; used_modules : ModuleName.t Ident.Map.t;
(** Module aliases and the modules they point to *) (** Module aliases and the modules they point to *)
is_external : bool;
} }
(** Context for name resolution, valid within a given module *) (** Context for name resolution, valid within a given module *)
type context = { type context = {
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdef_types : typ TopdefName.Map.t; topdefs : (typ * visibility) TopdefName.Map.t;
(** Types associated with the global definitions *) (** Types associated with the global definitions *)
structs : struct_context StructName.Map.t; structs : (struct_context * visibility) StructName.Map.t;
(** For each struct, its context *) (** For each struct, its context *)
enums : enum_context EnumName.Map.t; (** For each enum, its context *) enums : (enum_context * visibility) EnumName.Map.t;
(** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t; var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
modules : module_context ModuleName.Map.t; modules : module_context ModuleName.Map.t;

View File

@ -93,7 +93,7 @@ let load_module_interfaces
Surface.Parser_driver.load_interface ?default_module_name Surface.Parser_driver.load_interface ?default_module_name
(Global.FileName f) (Global.FileName f)
in in
let modname = ModuleName.fresh intf.intf_modname in let modname = ModuleName.fresh intf.intf_modname.module_name in
let seen = File.Map.add f None seen in let seen = File.Map.add f None seen in
let seen, sub_use_map = let seen, sub_use_map =
aux aux
@ -107,9 +107,9 @@ let load_module_interfaces
(seen, Ident.Map.empty) uses (seen, Ident.Map.empty) uses
in in
let seen = let seen =
match program.Surface.Ast.program_module_name with match program.Surface.Ast.program_module with
| Some m -> | Some m ->
let file = Pos.get_file (Mark.get m) in let file = Pos.get_file (Mark.get m.module_name) in
File.Map.singleton file None File.Map.singleton file None
| None -> File.Map.empty | None -> File.Map.empty
in in
@ -712,7 +712,12 @@ module Commands = struct
let prg, _ = let prg, _ =
Passes.dcalc options ~includes ~optimize ~check_invariants ~typed Passes.dcalc options ~includes ~optimize ~check_invariants ~typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules
~hashf:
Hash.(
finalise ~avoid_exceptions:false ~closure_conversion:false
~monomorphize_types:false)
prg;
print_interpretation_results options Interpreter.interpret_program_dcalc prg print_interpretation_results options Interpreter.interpret_program_dcalc prg
(get_scopeopt_uid prg.decl_ctx ex_scope_opt) (get_scopeopt_uid prg.decl_ctx ex_scope_opt)
@ -781,7 +786,10 @@ module Commands = struct
Passes.lcalc options ~includes ~optimize ~check_invariants Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~monomorphize_types ~typed ~avoid_exceptions ~closure_conversion ~monomorphize_types ~typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules
~hashf:
(Hash.finalise ~avoid_exceptions ~closure_conversion ~monomorphize_types)
prg;
print_interpretation_results options Interpreter.interpret_program_lcalc prg print_interpretation_results options Interpreter.interpret_program_lcalc prg
(get_scopeopt_uid prg.decl_ctx ex_scope_opt) (get_scopeopt_uid prg.decl_ctx ex_scope_opt)
@ -844,7 +852,11 @@ module Commands = struct
Message.debug "Writing to %s..." Message.debug "Writing to %s..."
(Option.value ~default:"stdout" output_file); (Option.value ~default:"stdout" output_file);
let exec_scope = Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt in let exec_scope = Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt in
Lcalc.To_ocaml.format_program fmt prg ?exec_scope type_ordering let hashf =
Hash.finalise ~avoid_exceptions ~closure_conversion:false
~monomorphize_types:false
in
Lcalc.To_ocaml.format_program fmt prg ?exec_scope ~hashf type_ordering
let ocaml_cmd = let ocaml_cmd =
Cmd.v Cmd.v
@ -1010,7 +1022,7 @@ module Commands = struct
let prg = let prg =
Surface.Ast. Surface.Ast.
{ {
program_module_name = None; program_module = None;
program_items = []; program_items = [];
program_source_files = []; program_source_files = [];
program_used_modules = program_used_modules =
@ -1038,7 +1050,7 @@ module Commands = struct
in in
Format.open_hbox (); Format.open_hbox ();
Format.pp_print_list ~pp_sep:Format.pp_print_space Format.pp_print_list ~pp_sep:Format.pp_print_space
(fun ppf m -> (fun ppf (m, _) ->
let f = Pos.get_file (Mark.get (ModuleName.get_info m)) in let f = Pos.get_file (Mark.get (ModuleName.get_info m)) in
let f = let f =
match prefix with match prefix with

View File

@ -716,9 +716,21 @@ let commands = if commands = [] then entry_scopes else commands
name format_var var name) name format_var var name)
scopes_with_no_input scopes_with_no_input
let reexport_used_modules fmt modules = let check_and_reexport_used_modules fmt ~hashf modules =
List.iter List.iter
(fun m -> (fun (m, intf_id) ->
Format.fprintf fmt
"@[<hv 2>let () =@ @[<hov 2>match Runtime_ocaml.Runtime.check_module \
%S \"%a\"@ with@]@,\
| Ok () -> ()@,\
@[<hv 2>| Error h -> failwith \"Hash mismatch for module %a, it may \
need recompiling\"@]@]@,"
(ModuleName.to_string m)
(fun ppf h ->
if intf_id.is_external then
Format.pp_print_string ppf Hash.external_placeholder
else Hash.format ppf h)
(hashf intf_id.hash) ModuleName.format m;
Format.fprintf fmt "@[<hv 2>module %a@ = %a@]@," ModuleName.format m Format.fprintf fmt "@[<hv 2>module %a@ = %a@]@," ModuleName.format m
ModuleName.format m) ModuleName.format m)
modules modules
@ -726,7 +738,9 @@ let reexport_used_modules fmt modules =
let format_module_registration let format_module_registration
fmt fmt
(bnd : ('m Ast.expr Var.t * _) String.Map.t) (bnd : ('m Ast.expr Var.t * _) String.Map.t)
modname = modname
hash
is_external =
Format.pp_open_vbox fmt 2; Format.pp_open_vbox fmt 2;
Format.pp_print_string fmt "let () ="; Format.pp_print_string fmt "let () =";
Format.pp_print_space fmt (); Format.pp_print_space fmt ();
@ -743,11 +757,17 @@ let format_module_registration
(fun fmt (id, (var, _)) -> (fun fmt (id, (var, _)) ->
Format.fprintf fmt "@[<hov 2>%S,@ Obj.repr %a@]" id format_var var) Format.fprintf fmt "@[<hov 2>%S,@ Obj.repr %a@]" id format_var var)
fmt (String.Map.to_seq bnd); fmt (String.Map.to_seq bnd);
(* TODO: pass the visibility info down from desugared, and filter what is
exported here *)
Format.pp_close_box fmt (); Format.pp_close_box fmt ();
Format.pp_print_char fmt ' '; Format.pp_print_char fmt ' ';
Format.pp_print_string fmt "]"; Format.pp_print_string fmt "]";
Format.pp_print_space fmt (); Format.pp_print_space fmt ();
Format.pp_print_string fmt "\"todo-module-hash\""; Format.fprintf fmt "\"%a\""
(fun ppf h ->
if is_external then Format.pp_print_string ppf Hash.external_placeholder
else Hash.format ppf h)
hash;
Format.pp_close_box fmt (); Format.pp_close_box fmt ();
Format.pp_close_box fmt (); Format.pp_close_box fmt ();
Format.pp_print_newline fmt () Format.pp_print_newline fmt ()
@ -766,17 +786,21 @@ let format_program
(fmt : Format.formatter) (fmt : Format.formatter)
?exec_scope ?exec_scope
?(exec_args = true) ?(exec_args = true)
~(hashf : Hash.t -> Hash.full)
(p : 'm Ast.program) (p : 'm Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) : unit = (type_ordering : Scopelang.Dependency.TVertex.t list) : unit =
Format.pp_open_vbox fmt 0; Format.pp_open_vbox fmt 0;
Format.pp_print_string fmt header; Format.pp_print_string fmt header;
reexport_used_modules fmt (Program.modules_to_list p.decl_ctx.ctx_modules); check_and_reexport_used_modules fmt ~hashf
(Program.modules_to_list p.decl_ctx.ctx_modules);
format_ctx type_ordering fmt p.decl_ctx; format_ctx type_ordering fmt p.decl_ctx;
let bnd = format_code_items p.decl_ctx fmt p.code_items in let bnd = format_code_items p.decl_ctx fmt p.code_items in
Format.pp_print_cut fmt (); Format.pp_print_cut fmt ();
let () = let () =
match p.module_name, exec_scope with match p.module_name, exec_scope with
| Some modname, None -> format_module_registration fmt bnd modname | Some (modname, intf_id), None ->
format_module_registration fmt bnd modname (hashf intf_id.hash)
intf_id.is_external
| None, Some scope_name -> | None, Some scope_name ->
let scope_body = Program.get_scope_body p scope_name in let scope_body = Program.get_scope_body p scope_name in
format_scope_exec p.decl_ctx fmt bnd scope_name scope_body format_scope_exec p.decl_ctx fmt bnd scope_name scope_body

View File

@ -14,6 +14,7 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Catala_utils
open Shared_ast open Shared_ast
(** Formats a lambda calculus program into a valid OCaml program *) (** Formats a lambda calculus program into a valid OCaml program *)
@ -40,6 +41,7 @@ val format_program :
Format.formatter -> Format.formatter ->
?exec_scope:ScopeName.t -> ?exec_scope:ScopeName.t ->
?exec_args:bool -> ?exec_args:bool ->
hashf:(Hash.t -> Hash.full) ->
'm Ast.program -> 'm Ast.program ->
Scopelang.Dependency.TVertex.t list -> Scopelang.Dependency.TVertex.t list ->
unit unit

View File

@ -489,7 +489,7 @@ let run
(Option.value ~default:"stdout" jsoo_output_file); (Option.value ~default:"stdout" jsoo_output_file);
let modname = let modname =
match prg.module_name with match prg.module_name with
| Some m -> ModuleName.to_string m | Some (m, _) -> ModuleName.to_string m
| None -> | None ->
String.capitalize_ascii String.capitalize_ascii
Filename.( Filename.(

View File

@ -1381,7 +1381,10 @@ let run includes optimize ex_scope explain_options global_options =
Driver.Passes.dcalc global_options ~includes ~optimize Driver.Passes.dcalc global_options ~includes ~optimize
~check_invariants:false ~typed:Expr.typed ~check_invariants:false ~typed:Expr.typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules prg
~hashf:
(Hash.finalise ~avoid_exceptions:false ~closure_conversion:false
~monomorphize_types:false);
let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in
(* let result_expr, env = interpret_program prg scope in *) (* let result_expr, env = interpret_program prg scope in *)
let g, base_vars, env = program_to_graph explain_options prg scope in let g, base_vars, env = program_to_graph explain_options prg scope in

View File

@ -271,7 +271,10 @@ let run includes optimize check_invariants ex_scope options =
Driver.Passes.dcalc options ~includes ~optimize ~check_invariants Driver.Passes.dcalc options ~includes ~optimize ~check_invariants
~typed:Expr.typed ~typed:Expr.typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules prg
~hashf:
(Hash.finalise ~avoid_exceptions:false ~closure_conversion:false
~monomorphize_types:false);
let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in
let result_expr, _env = interpret_program prg scope in let result_expr, _env = interpret_program prg scope in
let fmt = Format.std_formatter in let fmt = Format.std_formatter in

View File

@ -121,5 +121,5 @@ type ctx = { decl_ctx : decl_ctx; modules : VarName.t ModuleName.Map.t }
type program = { type program = {
ctx : ctx; ctx : ctx;
code_items : code_item list; code_items : code_item list;
module_name : ModuleName.t option; module_name : (ModuleName.t * module_intf_id) option;
} }

View File

@ -659,7 +659,7 @@ let translate_program ~(config : translation_config) (p : 'm L.program) :
A.program = A.program =
let modules = let modules =
List.fold_left List.fold_left
(fun acc m -> (fun acc (m, _) ->
let vname = Mark.map (( ^ ) "Module_") (ModuleName.get_info m) in let vname = Mark.map (( ^ ) "Module_") (ModuleName.get_info m) in
(* The "Module_" prefix is a workaround name clashes for same-name (* The "Module_" prefix is a workaround name clashes for same-name
structs and modules, Python in particular mixes everything in one structs and modules, Python in particular mixes everything in one

View File

@ -21,10 +21,10 @@ open Ast
let needs_parens (_e : expr) : bool = false let needs_parens (_e : expr) : bool = false
let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit = let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit =
Format.fprintf fmt "%a_%d" VarName.format v (VarName.hash v) Format.fprintf fmt "%a_%d" VarName.format v (VarName.id v)
let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit = let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
Format.fprintf fmt "@{<green>%a_%d@}" FuncName.format v (FuncName.hash v) Format.fprintf fmt "@{<green>%a_%d@}" FuncName.format v (FuncName.id v)
let rec format_expr let rec format_expr
(decl_ctx : decl_ctx) (decl_ctx : decl_ctx)

View File

@ -96,11 +96,11 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty
let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
let v_str = Mark.remove (VarName.get_info v) in let v_str = Mark.remove (VarName.get_info v) in
let hash = VarName.hash v in let id = VarName.id v in
let local_id = let local_id =
match StringMap.find_opt v_str !string_counter_map with match StringMap.find_opt v_str !string_counter_map with
| Some ids -> ( | Some ids -> (
match IntMap.find_opt hash ids with match IntMap.find_opt id ids with
| None -> | None ->
let max_id = let max_id =
snd snd
@ -111,13 +111,13 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
in in
string_counter_map := string_counter_map :=
StringMap.add v_str StringMap.add v_str
(IntMap.add hash (max_id + 1) ids) (IntMap.add id (max_id + 1) ids)
!string_counter_map; !string_counter_map;
max_id + 1 max_id + 1
| Some local_id -> local_id) | Some local_id -> local_id)
| None -> | None ->
string_counter_map := string_counter_map :=
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; StringMap.add v_str (IntMap.singleton id 0) !string_counter_map;
0 0
in in
if v_str = "_" then Format.fprintf fmt "dummy_var" if v_str = "_" then Format.fprintf fmt "dummy_var"

View File

@ -152,20 +152,20 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty
let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
let v_str = clean_name (Mark.remove (VarName.get_info v)) in let v_str = clean_name (Mark.remove (VarName.get_info v)) in
let hash = VarName.hash v in let id = VarName.id v in
let local_id = let local_id =
match StringMap.find_opt v_str !string_counter_map with match StringMap.find_opt v_str !string_counter_map with
| Some ids -> ( | Some ids -> (
match IntMap.find_opt hash ids with match IntMap.find_opt id ids with
| None -> | None ->
let id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in let local_id = 1 + IntMap.fold (fun _ -> Int.max) ids 0 in
string_counter_map := string_counter_map :=
StringMap.add v_str (IntMap.add hash id ids) !string_counter_map; StringMap.add v_str (IntMap.add id local_id ids) !string_counter_map;
id local_id
| Some local_id -> local_id) | Some local_id -> local_id)
| None -> | None ->
string_counter_map := string_counter_map :=
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; StringMap.add v_str (IntMap.singleton id 0) !string_counter_map;
0 0
in in
if v_str = "_" then Format.fprintf fmt "_" if v_str = "_" then Format.fprintf fmt "_"

View File

@ -220,11 +220,11 @@ let string_counter_map : int IntMap.t StringMap.t ref = ref StringMap.empty
let format_var (fmt : Format.formatter) (v : VarName.t) : unit = let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
let v_str = Mark.remove (VarName.get_info v) in let v_str = Mark.remove (VarName.get_info v) in
let hash = VarName.hash v in let id = VarName.id v in
let local_id = let local_id =
match StringMap.find_opt v_str !string_counter_map with match StringMap.find_opt v_str !string_counter_map with
| Some ids -> ( | Some ids -> (
match IntMap.find_opt hash ids with match IntMap.find_opt id ids with
| None -> | None ->
let max_id = let max_id =
snd snd
@ -235,13 +235,13 @@ let format_var (fmt : Format.formatter) (v : VarName.t) : unit =
in in
string_counter_map := string_counter_map :=
StringMap.add v_str StringMap.add v_str
(IntMap.add hash (max_id + 1) ids) (IntMap.add id (max_id + 1) ids)
!string_counter_map; !string_counter_map;
max_id + 1 max_id + 1
| Some local_id -> local_id) | Some local_id -> local_id)
| None -> | None ->
string_counter_map := string_counter_map :=
StringMap.add v_str (IntMap.singleton hash 0) !string_counter_map; StringMap.add v_str (IntMap.singleton id 0) !string_counter_map;
0 0
in in
if v_str = "_" then Format.fprintf fmt "dummy_var" if v_str = "_" then Format.fprintf fmt "dummy_var"

View File

@ -67,7 +67,7 @@ type 'm scope_decl = {
} }
type 'm program = { type 'm program = {
program_module_name : ModuleName.t option; program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx; program_ctx : decl_ctx;
program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t;
program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t;

View File

@ -63,12 +63,13 @@ type 'm scope_decl = {
} }
type 'm program = { type 'm program = {
program_module_name : ModuleName.t option; program_module_name : (ModuleName.t * module_intf_id) option;
program_ctx : decl_ctx; program_ctx : decl_ctx;
program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t;
(* Using [nil] here ensure that program interfaces don't contain any (* Using [nil] here ensure that program interfaces don't contain any
expressions. They won't contain any rules or topdefs, but will still have expressions. They won't contain any rules or topdef implementations, but
the scope signatures needed to respect the call convention *) will still have the scope signatures needed to respect the call
convention *)
program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t;
program_topdefs : ('m expr * typ) TopdefName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t;
program_lang : Global.backend_lang; program_lang : Global.backend_lang;

View File

@ -42,9 +42,7 @@ module SVertex = struct
| Topdef g1, Topdef g2 -> TopdefName.equal g1 g2 | Topdef g1, Topdef g2 -> TopdefName.equal g1 g2
| (Scope _ | Topdef _), _ -> false | (Scope _ | Topdef _), _ -> false
let hash = function let hash = function Scope s -> ScopeName.id s | Topdef g -> TopdefName.id g
| Scope s -> ScopeName.hash s
| Topdef g -> TopdefName.hash g
let format ppf = function let format ppf = function
| Scope s -> ScopeName.format ppf s | Scope s -> ScopeName.format ppf s
@ -206,7 +204,9 @@ module TVertex = struct
type t = Struct of StructName.t | Enum of EnumName.t type t = Struct of StructName.t | Enum of EnumName.t
let hash x = let hash x =
match x with Struct x -> StructName.hash x | Enum x -> EnumName.hash x match x with
| Struct x -> StructName.id x
| Enum x -> Hashtbl.hash (`Enum (EnumName.id x))
let compare x y = let compare x y =
match x, y with match x, y with

View File

@ -953,8 +953,9 @@ let translate_program
let program_topdefs = let program_topdefs =
TopdefName.Map.mapi TopdefName.Map.mapi
(fun id -> function (fun id -> function
| Some e, ty -> Expr.unbox (translate_expr ctx e), ty | { D.topdef_expr = Some e; topdef_type = ty; topdef_visibility = _ } ->
| None, (_, pos) -> Expr.unbox (translate_expr ctx e), ty
| { D.topdef_expr = None; topdef_type = _, pos; _ } ->
Message.error ~pos "No definition found for %a" TopdefName.format id) Message.error ~pos "No definition found for %a" TopdefName.format id)
desugared.program_root.module_topdefs desugared.program_root.module_topdefs
in in
@ -964,8 +965,7 @@ let translate_program
desugared.D.program_root.module_scopes desugared.D.program_root.module_scopes
in in
{ {
Ast.program_module_name = Ast.program_module_name = desugared.D.program_module_name;
Option.map ModuleName.fresh desugared.D.program_module_name;
Ast.program_topdefs; Ast.program_topdefs;
Ast.program_scopes; Ast.program_scopes;
Ast.program_ctx = ctx.decl_ctx; Ast.program_ctx = ctx.decl_ctx;

View File

@ -668,8 +668,14 @@ type scope_info = {
out_struct_fields : StructField.t ScopeVar.Map.t; out_struct_fields : StructField.t ScopeVar.Map.t;
} }
type module_intf_id = { hash : Hash.t; is_external : bool }
type module_tree_node = { deps : module_tree; intf_id : module_intf_id }
and module_tree = module_tree_node ModuleName.Map.t
(** In practice, this is a DAG: beware of repeated names *) (** In practice, this is a DAG: beware of repeated names *)
type module_tree = M of module_tree ModuleName.Map.t [@@caml.unboxed]
type visibility = Private | Public
type decl_ctx = { type decl_ctx = {
ctx_enums : enum_ctx; ctx_enums : enum_ctx;
@ -688,5 +694,5 @@ type 'e program = {
decl_ctx : decl_ctx; decl_ctx : decl_ctx;
code_items : 'e code_item_list; code_items : 'e code_item_list;
lang : Global.backend_lang; lang : Global.backend_lang;
module_name : ModuleName.t option; module_name : (ModuleName.t * module_intf_id) option;
} }

View File

@ -1155,29 +1155,57 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list
reflect that. *) reflect that. *)
let evaluate_expr ctx lang e = evaluate_expr ctx lang (addcustom e) let evaluate_expr ctx lang e = evaluate_expr ctx lang (addcustom e)
let load_runtime_modules prg = let load_runtime_modules ~hashf prg =
let load m = let load (mname, intf_id) =
let hash = hashf intf_id.hash in
let expect_hash =
if intf_id.is_external then Hash.external_placeholder
else Hash.to_string hash
in
let obj_file = let obj_file =
Dynlink.adapt_filename Dynlink.adapt_filename
File.(Pos.get_file (Mark.get (ModuleName.get_info m)) -.- "cmo") File.(Pos.get_file (Mark.get (ModuleName.get_info mname)) -.- "cmo")
in in
if not (Sys.file_exists obj_file) then (if not (Sys.file_exists obj_file) then
Message.error Message.error
~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") ~pos_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here")
~pos:(Mark.get (ModuleName.get_info m)) ~pos:(Mark.get (ModuleName.get_info mname))
"Compiled OCaml object %a@ not@ found.@ Make sure it has been suitably \ "Compiled OCaml object %a@ not@ found.@ Make sure it has been \
compiled." suitably compiled."
File.format obj_file File.format obj_file
else else
try Dynlink.loadfile obj_file try Dynlink.loadfile obj_file
with Dynlink.Error dl_err -> with Dynlink.Error dl_err ->
Message.error "Error loading compiled module from %a:@;<1 2>@[<hov>%a@]" Message.error
"While loading compiled module from %a:@;<1 2>@[<hov>%a@]"
File.format obj_file Format.pp_print_text File.format obj_file Format.pp_print_text
(Dynlink.error_message dl_err) (Dynlink.error_message dl_err));
match Runtime.check_module (ModuleName.to_string mname) expect_hash with
| Ok () -> ()
| Error bad_hash ->
Message.debug
"Module hash mismatch for %a:@ @[<v>Expected: %a@,Found: %a@]"
ModuleName.format mname Hash.format hash
(fun ppf h ->
try Hash.format ppf (Hash.of_string h)
with Failure _ ->
if h = Hash.external_placeholder then
Format.fprintf ppf "@{<cyan>%s@}" Hash.external_placeholder
else Format.fprintf ppf "@{<red><invalid>@}")
bad_hash;
Message.error
"Module %a@ needs@ recompiling:@ %a@ was@ likely@ compiled@ from@ an@ \
older@ version@ or@ with@ incompatible@ flags."
ModuleName.format mname File.format obj_file
| exception Not_found ->
Message.error
"Module %a@ was loaded from file %a but did not register properly, \
there is something wrong in its code."
ModuleName.format mname File.format obj_file
in in
let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in let modules_list_topo = Program.modules_to_list prg.decl_ctx.ctx_modules in
if modules_list_topo <> [] then if modules_list_topo <> [] then
Message.debug "Loading shared modules... %a" Message.debug "Loading shared modules... %a"
(Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format) (Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format)
modules_list_topo; (List.map (fun (m, _) -> m) modules_list_topo);
List.iter load modules_list_topo List.iter load modules_list_topo

View File

@ -62,6 +62,6 @@ val delcustom :
(** Runtime check that the term contains no custom terms (raises (** Runtime check that the term contains no custom terms (raises
[Invalid_argument] if that is the case *) [Invalid_argument] if that is the case *)
val load_runtime_modules : _ program -> unit val load_runtime_modules : hashf:(Hash.t -> Hash.full) -> _ program -> unit
(** Dynlink the runtime modules required by the given program, in order to make (** Dynlink the runtime modules required by the given program, in order to make
them callable by the interpreter. *) them callable by the interpreter. *)

View File

@ -58,7 +58,7 @@ let empty_ctx =
ctx_struct_fields = Ident.Map.empty; ctx_struct_fields = Ident.Map.empty;
ctx_enum_constrs = Ident.Map.empty; ctx_enum_constrs = Ident.Map.empty;
ctx_scope_index = Ident.Map.empty; ctx_scope_index = Ident.Map.empty;
ctx_modules = M ModuleName.Map.empty; ctx_modules = ModuleName.Map.empty;
} }
let get_scope_body { code_items; _ } scope = let get_scope_body { code_items; _ } scope =
@ -87,11 +87,11 @@ let to_expr p main_scope =
res res
let modules_to_list (mt : module_tree) = let modules_to_list (mt : module_tree) =
let rec aux acc (M mtree) = let rec aux acc mtree =
ModuleName.Map.fold ModuleName.Map.fold
(fun mname sub acc -> (fun mname mnode acc ->
if List.exists (ModuleName.equal mname) acc then acc if List.exists (fun (m, _) -> ModuleName.equal m mname) acc then acc
else mname :: aux acc sub) else (mname, mnode.intf_id) :: aux acc mnode.deps)
mtree acc mtree acc
in in
List.rev (aux [] mt) List.rev (aux [] mt)

View File

@ -53,5 +53,6 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed
val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body val find_scope : ScopeName.t -> 'e code_item_list -> 'e scope_body
val modules_to_list : module_tree -> ModuleName.t list val modules_to_list : module_tree -> (ModuleName.t * module_intf_id) list
(** Returns a list of used modules, in topological order *) (** Returns a list of used modules, in topological order ; the boolean indicates
if the module is external *)

View File

@ -93,6 +93,22 @@ let rec compare ty1 ty2 =
| TClosureEnv, _ -> -1 | TClosureEnv, _ -> -1
| _, TClosureEnv -> 1 | _, TClosureEnv -> 1
let rec hash ~strip ty =
let open Hash.Op in
match Mark.remove ty with
| TLit l -> !`TLit % !(l : typ_lit)
| TTuple tl -> List.fold_left (fun acc ty -> acc % hash ~strip ty) !`TTuple tl
| TStruct n -> !`TStruct % StructName.hash ~strip n
| TEnum n -> !`TEnum % EnumName.hash ~strip n
| TOption ty -> !`TOption % hash ~strip ty
| TArrow (tl, ty) ->
!`TArrow
% List.fold_left (fun acc ty -> acc % hash ~strip ty) (hash ~strip ty) tl
| TArray ty -> !`TArray % hash ~strip ty
| TDefault ty -> !`TDefault % hash ~strip ty
| TAny -> !`TAny
| TClosureEnv -> !`TClosureEnv
let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t
let format = Print.typ_debug let format = Print.typ_debug

View File

@ -14,6 +14,8 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Catala_utils
type t = Definitions.typ type t = Definitions.typ
val format : Format.formatter -> t -> unit val format : Format.formatter -> t -> unit
@ -23,6 +25,11 @@ module Map : Catala_utils.Map.S with type key = t
val equal : t -> t -> bool val equal : t -> t -> bool
val equal_list : t list -> t list -> bool val equal_list : t list -> t list -> bool
val compare : t -> t -> int val compare : t -> t -> int
val hash : strip:Uid.Path.t -> t -> Hash.t
(** The [strip] argument strips the given leading path components in included
identifiers before hashing *)
val unifiable : t -> t -> bool val unifiable : t -> t -> bool
val unifiable_list : t list -> t list -> bool val unifiable_list : t list -> t list -> bool

View File

@ -31,6 +31,7 @@ module Any =
let format fmt () = Format.fprintf fmt "any" let format fmt () = Format.fprintf fmt "any"
let equal () () = true let equal () () = true
let compare () () = 0 let compare () () = 0
let hash () = Hash.raw `Any
end) end)
(struct (struct
let style = Ocolor_types.(Fg (C4 hi_magenta)) let style = Ocolor_types.(Fg (C4 hi_magenta))
@ -166,7 +167,7 @@ let rec format_typ
format_typ ~colors fmt t1; format_typ ~colors fmt t1;
Format.pp_print_as fmt 1 "" Format.pp_print_as fmt 1 ""
| TAny v -> | TAny v ->
if Global.options.debug then Format.fprintf fmt "<a%d>" (Any.hash v) if Global.options.debug then Format.fprintf fmt "<a%d>" (Any.id v)
else Format.pp_print_string fmt "<any>" else Format.pp_print_string fmt "<any>"
| TClosureEnv -> Format.fprintf fmt "closure_env" | TClosureEnv -> Format.fprintf fmt "closure_env"

View File

@ -318,7 +318,7 @@ and law_structure =
| CodeBlock of code_block * source_repr * bool (* Metadata if true *) | CodeBlock of code_block * source_repr * bool (* Metadata if true *)
and interface = { and interface = {
intf_modname : uident Mark.pos; intf_modname : program_module;
intf_code : code_block; intf_code : code_block;
(** Invariant: an interface shall only contain [*Decl] elements, or (** Invariant: an interface shall only contain [*Decl] elements, or
[Topdef] elements with [topdef_expr = None] *) [Topdef] elements with [topdef_expr = None] *)
@ -330,8 +330,10 @@ and module_use = {
mod_use_alias : uident Mark.pos; mod_use_alias : uident Mark.pos;
} }
and program_module = { module_name : uident Mark.pos; module_external : bool }
and program = { and program = {
program_module_name : uident Mark.pos option; program_module : program_module option;
program_items : law_structure list; program_items : law_structure list;
program_source_files : (string[@opaque]) list; program_source_files : (string[@opaque]) list;
program_used_modules : module_use list; program_used_modules : module_use list;

View File

@ -259,18 +259,21 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
List.fold_left List.fold_left
(fun acc command -> (fun acc command ->
let join_module_names name_opt = let join_module_names name_opt =
match acc.Ast.program_module_name, name_opt with match acc.Ast.program_module, name_opt with
| opt, None | None, opt -> opt | opt, None | None, opt -> opt
| Some id1, Some id2 -> | Some id1, Some id2 ->
Message.error Message.error
~extra_pos:["", Mark.get id1; "", Mark.get id2] ~extra_pos:
["", Mark.get id1.module_name; "", Mark.get id2.module_name]
"Multiple definitions of the module name" "Multiple definitions of the module name"
in in
match command with match command with
| Ast.ModuleDef (id, _) -> | Ast.ModuleDef (id, is_external) ->
{ {
acc with acc with
Ast.program_module_name = join_module_names (Some id); Ast.program_module =
join_module_names
(Some { module_name = id; module_external = is_external });
Ast.program_items = command :: acc.Ast.program_items; Ast.program_items = command :: acc.Ast.program_items;
} }
| Ast.ModuleUse (mod_use_name, alias) -> | Ast.ModuleUse (mod_use_name, alias) ->
@ -288,22 +291,22 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
@@ fun lexbuf -> @@ fun lexbuf ->
let includ_program = parse_source lexbuf in let includ_program = parse_source lexbuf in
let () = let () =
includ_program.Ast.program_module_name includ_program.Ast.program_module
|> Option.iter |> Option.iter
@@ fun id -> @@ fun id ->
Message.error Message.error
~extra_pos: ~extra_pos:
[ [
"File include", Mark.get inc_file; "File include", Mark.get inc_file;
"Module declaration", Mark.get id; "Module declaration", Mark.get id.Ast.module_name;
] ]
"A file that declares a module cannot be used through the raw \ "A file that declares a module cannot be used through the raw \
'@{<yellow>> Include@}'@ directive.@ You should use it as a \ '@{<yellow>> Include@}'@ directive.@ You should use it as a \
module with@ '@{<yellow>> Use @{<blue>%s@}@}'@ instead." module with@ '@{<yellow>> Use @{<blue>%s@}@}'@ instead."
(Mark.remove id) (Mark.remove id.Ast.module_name)
in in
{ {
Ast.program_module_name = acc.program_module_name; Ast.program_module = acc.program_module;
Ast.program_source_files = Ast.program_source_files =
List.rev_append includ_program.program_source_files List.rev_append includ_program.program_source_files
acc.Ast.program_source_files; acc.Ast.program_source_files;
@ -316,7 +319,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
} }
| Ast.LawHeading (heading, commands') -> | Ast.LawHeading (heading, commands') ->
let { let {
Ast.program_module_name; Ast.program_module;
Ast.program_items = commands'; Ast.program_items = commands';
Ast.program_source_files = new_sources; Ast.program_source_files = new_sources;
Ast.program_used_modules = new_used_modules; Ast.program_used_modules = new_used_modules;
@ -325,7 +328,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
expand_includes source_file commands' expand_includes source_file commands'
in in
{ {
Ast.program_module_name = join_module_names program_module_name; Ast.program_module = join_module_names program_module;
Ast.program_source_files = Ast.program_source_files =
List.rev_append new_sources acc.Ast.program_source_files; List.rev_append new_sources acc.Ast.program_source_files;
Ast.program_items = Ast.program_items =
@ -336,7 +339,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
} }
| i -> { acc with Ast.program_items = i :: acc.Ast.program_items }) | i -> { acc with Ast.program_items = i :: acc.Ast.program_items })
{ {
Ast.program_module_name = None; Ast.program_module = None;
Ast.program_source_files = []; Ast.program_source_files = [];
Ast.program_items = []; Ast.program_items = [];
Ast.program_used_modules = []; Ast.program_used_modules = [];
@ -346,7 +349,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
in in
{ {
Ast.program_lang = language; Ast.program_lang = language;
Ast.program_module_name = rprg.Ast.program_module_name; Ast.program_module = rprg.Ast.program_module;
Ast.program_source_files = List.rev rprg.Ast.program_source_files; Ast.program_source_files = List.rev rprg.Ast.program_source_files;
Ast.program_items = List.rev rprg.Ast.program_items; Ast.program_items = List.rev rprg.Ast.program_items;
Ast.program_used_modules = List.rev rprg.Ast.program_used_modules; Ast.program_used_modules = List.rev rprg.Ast.program_used_modules;
@ -396,8 +399,8 @@ let with_sedlex_source source_file f =
f lexbuf f lexbuf
let check_modname program source_file = let check_modname program source_file =
match program.Ast.program_module_name, source_file with match program.Ast.program_module, source_file with
| ( Some (mname, pos), | ( Some { module_name = mname, pos; _ },
(Global.FileName file | Global.Contents (_, file) | Global.Stdin file) ) (Global.FileName file | Global.Contents (_, file) | Global.Stdin file) )
when not File.(equal mname Filename.(remove_extension (basename file))) -> when not File.(equal mname Filename.(remove_extension (basename file))) ->
Message.error ~pos Message.error ~pos
@ -413,10 +416,14 @@ let load_interface ?default_module_name source_file =
let program = with_sedlex_source source_file parse_source in let program = with_sedlex_source source_file parse_source in
check_modname program source_file; check_modname program source_file;
let modname = let modname =
match program.Ast.program_module_name, default_module_name with match program.Ast.program_module, default_module_name with
| Some mname, _ -> mname | Some mname, _ -> mname
| None, Some n -> | None, Some n ->
n, Pos.from_info (Global.input_src_file source_file) 0 0 0 0 {
module_name =
n, Pos.from_info (Global.input_src_file source_file) 0 0 0 0;
module_external = false;
}
| None, None -> | None, None ->
Message.error Message.error
"%a doesn't define a module name. It should contain a '@{<cyan>> \ "%a doesn't define a module name. It should contain a '@{<cyan>> \

View File

@ -31,10 +31,10 @@ catala implementation and compile to OCaml (removing the `external` directive):
``` ```
```shell-session ```shell-session
$ clerk build _build/.../Prorata_external.ml $ clerk build _build/.../prorata_external.ml
``` ```
(beware the `_build/`, and the capitalisation of the module name) (beware the `_build/`, it is required here)
## Write the OCaml implementation ## Write the OCaml implementation
@ -44,9 +44,11 @@ capitalisation to match). Edit to replace the dummy implementation by your code.
Refer to `runtimes/ocaml/runtime.mli` for what is available (especially the Refer to `runtimes/ocaml/runtime.mli` for what is available (especially the
`Oper` module to manipulate the types). `Oper` module to manipulate the types).
Keep the `register_module` at the end as is, it's needed for the toplevel to use Keep the `register_module` at the end, but replace the hash (which should be of
the value (you would get `Failure("Could not resolve reference to Xxx")` during the form `"CM0|XXXXXXXX|XXXXXXXX|XXXXXXXX"`) by the string `"*external*"`. This
evaluation). section is needed for the Catala interpreter to find the declared values --- the
error `Failure("Could not resolve reference to Xxx")` during evaluation is a
symptom that it is missing.
## Compile and test ## Compile and test

View File

@ -897,7 +897,9 @@ let register_module modname values hash =
Hashtbl.add modules_table modname hash; Hashtbl.add modules_table modname hash;
List.iter (fun (id, v) -> Hashtbl.add values_table ([modname], id) v) values List.iter (fun (id, v) -> Hashtbl.add values_table ([modname], id) v) values
let check_module m h = String.equal (Hashtbl.find modules_table m) h let check_module m h =
let h1 = Hashtbl.find modules_table m in
if String.equal h h1 then Ok () else Error h1
let lookup_value qid = let lookup_value qid =
try Hashtbl.find values_table qid try Hashtbl.find values_table qid

View File

@ -446,8 +446,8 @@ val register_module : string -> (string * Obj.t) list -> hash -> unit
expected to be a hash of the source file and the Catala version, and will in expected to be a hash of the source file and the Catala version, and will in
time be used to ensure that the module and the interface are in sync *) time be used to ensure that the module and the interface are in sync *)
val check_module : string -> hash -> bool val check_module : string -> hash -> (unit, hash) result
(** Returns [true] if it has been registered with the correct hash, [false] if (** Returns [Ok] if it has been registered with the correct hash, [Error h] if
there is a hash mismatch. there is a hash mismatch.
@raise Not_found if the module does not exist at all *) @raise Not_found if the module does not exist at all *)

View File

@ -19,12 +19,20 @@ declaration scope S:
declaration half content decimal declaration half content decimal
depends on x content integer depends on x content integer
equals x / 2 equals x / 2
declaration maybe content Enum1
depends on x content Enum1
``` ```
```catala ```catala
scope S: scope S:
definition sr equals $1,000 definition sr equals $1,000
definition e1 equals Maybe definition e1 equals Maybe
declaration maybe content Enum1
depends on x content Enum1
equals Maybe
``` ```

View File

@ -30,7 +30,7 @@ let s (s_in: S_in.t) : S.t =
try try
(handle_default (handle_default
[|{filename="tests/modules/good/mod_def.catala_en"; [|{filename="tests/modules/good/mod_def.catala_en";
start_line=26; start_column=24; end_line=26; end_column=30; start_line=29; start_column=24; end_line=29; end_column=30;
law_headings=["Test modules + inclusions 1"]}|] law_headings=["Test modules + inclusions 1"]}|]
([|(fun (_: unit) -> ([|(fun (_: unit) ->
handle_default [||] ([||]) (fun (_: unit) -> true) handle_default [||] ([||]) (fun (_: unit) -> true)
@ -47,7 +47,7 @@ let s (s_in: S_in.t) : S.t =
try try
(handle_default (handle_default
[|{filename="tests/modules/good/mod_def.catala_en"; [|{filename="tests/modules/good/mod_def.catala_en";
start_line=27; start_column=24; end_line=27; end_column=29; start_line=30; start_column=24; end_line=30; end_column=29;
law_headings=["Test modules + inclusions 1"]}|] law_headings=["Test modules + inclusions 1"]}|]
([|(fun (_: unit) -> ([|(fun (_: unit) ->
handle_default [||] ([||]) (fun (_: unit) -> true) handle_default [||] ([||]) (fun (_: unit) -> true)
@ -70,8 +70,12 @@ let half_ : integer -> decimal =
law_headings=["Test modules + inclusions 1"]} x_ (integer_of_string law_headings=["Test modules + inclusions 1"]} x_ (integer_of_string
"2") "2")
let maybe_ : Enum1.t -> Enum1.t =
fun (_: Enum1.t) -> Enum1.Maybe ()
let () = let () =
Runtime_ocaml.Runtime.register_module "Mod_def" Runtime_ocaml.Runtime.register_module "Mod_def"
[ "S", Obj.repr s; [ "S", Obj.repr s;
"half", Obj.repr half_ ] "half", Obj.repr half_;
"todo-module-hash" "maybe", Obj.repr maybe_ ]
"CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX"

View File

@ -37,4 +37,4 @@ let () =
Runtime_ocaml.Runtime.register_module "Prorata_external" Runtime_ocaml.Runtime.register_module "Prorata_external"
[ "prorata", Obj.repr prorata_; [ "prorata", Obj.repr prorata_;
"prorata2", Obj.repr prorata2_ ] "prorata2", Obj.repr prorata2_ ]
"todo-module-hash" "*external*"

View File

@ -90,5 +90,5 @@ let s (s_in: S_in.t) : S.t =
let () = let () =
Runtime_ocaml.Runtime.register_module "Let_in2" Runtime_ocaml.Runtime.register_module "Let_in2"
[ "S", Obj.repr s ] [ "S", Obj.repr s ]
"todo-module-hash" "CMX|XXXXXXXX|XXXXXXXX|XXXXXXXX"
``` ```