diff --git a/build_system/clerk_driver.ml b/build_system/clerk_driver.ml index bb91256b..8ad5b0a8 100644 --- a/build_system/clerk_driver.ml +++ b/build_system/clerk_driver.ml @@ -545,6 +545,14 @@ let[@ocamlformat "disable"] static_base_rules = "fi"; ] ~description:[""; !output]; + (* Note: this last rule looks horrible, but the processing is pretty simple: + in the rules above, we output the returning code of diffing individual + tests to a [@test] file, then the rules for directories just + concat these files. What this last rule does is then just count the number + of `0` and the total number of characters in the file, and print a readable + message. Instead of this disgusting shell code embedded in the ninja file, + this could be a specialised subcommand of clerk, e.g. `clerk + test-diagnostic ` *) ] let gen_build_statements @@ -641,7 +649,7 @@ let gen_build_statements (if Filename.is_relative d then !Var.builddir / d else d); ]) include_dirs - @ (List.map (fun m -> m ^".cmx") modules) ); + @ List.map (fun m -> m ^ ".cmx") modules ); ] in let expose_module = @@ -694,6 +702,7 @@ let gen_build_statements diff; it should actually be an output for the cases when we reset but that shouldn't cause trouble. *) Nj.build "post-test" ~inputs:[reference; test_out] + ~implicit_in:["always"] ~outputs:[reference ^ "@post"] :: acc) [] item.legacy_tests @@ -720,7 +729,8 @@ let gen_build_statements ~outputs:[inc (srcv ^ "@test")] ~inputs:[srcv; inc (srcv ^ "@out")] ~implicit_in: - (List.map + ("always" :: + List.map (fun test -> legacy_test_reference test ^ "@post") item.legacy_tests); results; @@ -801,7 +811,8 @@ let gen_ninja_file catala_exe catala_flags build_dir include_dirs dir = @+ List.to_seq (base_bindings catala_exe catala_flags build_dir include_dirs) @+ Seq.return (Nj.Comment "\n- Base rules - #\n") @+ List.to_seq static_base_rules - @+ Seq.return (Nj.Comment "- Project-specific build statements - #") + @+ Seq.return (Nj.build "phony" ~outputs:["always"]) + @+ Seq.return (Nj.Comment "\n- Project-specific build statements - #") @+ build_statements include_dirs dir @+ Seq.return (Nj.build "phony" ~outputs:["test"] ~inputs:[".@test"]) diff --git a/compiler/catala_utils/map.ml b/compiler/catala_utils/map.ml index 2d56a8c2..1d54d8a8 100644 --- a/compiler/catala_utils/map.ml +++ b/compiler/catala_utils/map.ml @@ -36,6 +36,7 @@ module type S = sig val keys : 'a t -> key list val values : 'a t -> 'a list val of_list : (key * 'a) list -> 'a t + val disjoint_union : 'a t -> 'a t -> 'a t val format_keys : ?pp_sep:(Format.formatter -> unit -> unit) -> @@ -87,6 +88,12 @@ module Make (Ord : OrderedType) : S with type key = Ord.t = struct let keys t = fold (fun k _ acc -> k :: acc) t [] |> List.rev let values t = fold (fun _ v acc -> v :: acc) t [] |> List.rev let of_list l = List.fold_left (fun m (k, v) -> add k v m) empty l + let disjoint_union t1 t2 = + union (fun k _ _ -> + Format.kasprintf failwith + "Maps are not disjoint: conflict on key %a" + Ord.format k) + t1 t2 let format_keys ?pp_sep ppf t = Format.pp_print_list ?pp_sep Ord.format ppf (keys t) diff --git a/compiler/catala_utils/uid.ml b/compiler/catala_utils/uid.ml index 3612bb12..123f066c 100644 --- a/compiler/catala_utils/uid.ml +++ b/compiler/catala_utils/uid.ml @@ -32,6 +32,7 @@ module type Id = sig val compare : t -> t -> int val equal : t -> t -> bool val format : Format.formatter -> t -> unit + val to_string : t -> string val hash : t -> int module Set : Set.S with type elt = t @@ -68,6 +69,8 @@ module Make (X : Info) (S : Style) () : Id with type info = X.info = struct let get_info (uid : t) : X.info = uid.info let hash (x : t) : int = x.id + let to_string t = X.to_string t.info + module Set = Set.Make (Ordering) module Map = Map.Make (Ordering) end @@ -87,27 +90,12 @@ module Gen (S : Style) () = Make (MarkedString) (S) () (* - Modules, paths and qualified idents - *) -module Module = struct - module Ordering = struct - type t = string Mark.pos - - let equal = Mark.equal String.equal - let compare = Mark.compare String.compare - let format ppf m = Format.fprintf ppf "@{%s@}" (Mark.remove m) - end - - include Ordering - - let to_string m = Mark.remove m - let of_string m = m - let pos m = Mark.get m - - module Set = Set.Make (Ordering) - module Map = Map.Make (Ordering) -end -(* TODO: should probably be turned into an uid once we implement module import - directives; that will incur an additional resolution work on all paths though - ([module Module = Gen ()]) *) +module Module = + Gen + (struct + let style = Ocolor_types.(Fg (C4 blue)) + end) + () module Path = struct type t = Module.t list diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index 862b856c..4ed2e8d7 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -47,6 +47,7 @@ module type Id = sig val compare : t -> t -> int val equal : t -> t -> bool val format : Format.formatter -> t -> unit + val to_string : t -> string val hash : t -> int module Set : Set.S with type elt = t @@ -62,27 +63,14 @@ end (** This is the generative functor that ensures that two modules resulting from two different calls to [Make] will be viewed as different types [t] by the OCaml typechecker. Prevents mixing up different sorts of identifiers. *) -module Make (X : Info) (S : Style) () : Id with type info = X.info +module Make (X : Info) (_ : Style) () : Id with type info = X.info (** Shortcut for creating a kind of uids over marked strings *) -module Gen (S : Style) () : Id with type info = MarkedString.info +module Gen (_ : Style) () : Id with type info = MarkedString.info (** {2 Handling of Uids with additional path information} *) -module Module : sig - type t = private string Mark.pos - (* TODO: this will become an uid at some point *) - - val to_string : t -> string - val format : Format.formatter -> t -> unit - val pos : t -> Pos.t - val equal : t -> t -> bool - val compare : t -> t -> int - val of_string : string * Pos.t -> t - - module Set : Set.S with type elt = t - module Map : Map.S with type key = t -end +module Module : Id with type info = MarkedString.info module Path : sig type t = Module.t list @@ -94,7 +82,7 @@ module Path : sig end (** Same as [Gen] but also registers path information *) -module Gen_qualified (S : Style) () : sig +module Gen_qualified (_ : Style) () : sig include Id with type info = Path.t * MarkedString.info val fresh : Path.t -> MarkedString.info -> t diff --git a/compiler/catala_web_interpreter.ml b/compiler/catala_web_interpreter.ml index 3a62e70d..3816b248 100644 --- a/compiler/catala_web_interpreter.ml +++ b/compiler/catala_web_interpreter.ml @@ -23,10 +23,10 @@ let () = ~input_src:(Contents (contents, "-inline-")) ~language:(Some language) ~debug:false ~color:Never ~trace () in - let prg, ctx, _type_order = + let prg, _type_order = Passes.dcalc options ~includes:[] ~optimize:false ~check_invariants:false ~typed:Shared_ast.Expr.typed in Shared_ast.Interpreter.interpret_program_dcalc prg - (Commands.get_scope_uid ctx scope) + (Commands.get_scope_uid prg.decl_ctx scope) end) diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index c6f45961..a6c4e415 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -47,15 +47,10 @@ type 'm scope_sig_ctx = { (** Mapping between the input scope variables and the input struct fields. *) } -type 'm scope_sigs_ctx = { - scope_sigs : 'm scope_sig_ctx ScopeName.Map.t; - scope_sigs_modules : 'm scope_sigs_ctx ModuleName.Map.t; -} - type 'm ctx = { decl_ctx : decl_ctx; scope_name : ScopeName.t option; - scopes_parameters : 'm scope_sigs_ctx; + scopes_parameters : 'm scope_sig_ctx ScopeName.Map.t; toplevel_vars : ('m Ast.expr Var.t * naked_typ) TopdefName.Map.t; scope_vars : ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVar.Map.t; @@ -77,14 +72,6 @@ let pos_mark_mk (type a m) (e : (a, m) gexpr) : let pos_mark_as e = pos_mark (Mark.get e) in pos_mark, pos_mark_as -let module_scope_sig scope_sig_ctx scope = - let ssctx = - List.fold_left - (fun ssctx m -> ModuleName.Map.find m ssctx.scope_sigs_modules) - scope_sig_ctx (ScopeName.path scope) - in - ScopeName.Map.find scope ssctx.scope_sigs - let merge_defaults ~(is_func : bool) (caller : (dcalc, 'm) boxed_gexpr) @@ -261,7 +248,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : Expr.ematch ~e:e1 ~name ~cases:d_cases m | EScopeCall { scope; args } -> let pos = Expr.mark_pos m in - let sc_sig = module_scope_sig ctx.scopes_parameters scope in + let sc_sig = ScopeName.Map.find scope ctx.scopes_parameters in let in_var_map = ScopeVar.Map.merge (fun var_name (str_field : scope_input_var_ctx option) expr -> @@ -522,10 +509,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : |> SubScopeName.Map.find (Mark.remove alias) |> retrieve_in_and_out_typ_or_any var | ELocation (ToplevelVar { name }) -> ( - let decl_ctx = - Program.module_ctx ctx.decl_ctx (TopdefName.path (Mark.remove name)) - in - let typ = TopdefName.Map.find (Mark.remove name) decl_ctx.ctx_topdefs in + let typ = TopdefName.Map.find (Mark.remove name) ctx.decl_ctx.ctx_topdefs in match Mark.remove typ with | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout | _ -> @@ -735,10 +719,9 @@ let translate_rule could be made more specific to avoid this case, but the added complexity didn't seem worth it *) | Call (subname, subindex, m) -> - let subscope_sig = module_scope_sig ctx.scopes_parameters subname in + let subscope_sig = ScopeName.Map.find subname ctx.scopes_parameters in let scope_sig_decl = - ScopeName.Map.find subname - (Program.module_ctx ctx.decl_ctx (ScopeName.path subname)).ctx_scopes + ScopeName.Map.find subname ctx.decl_ctx.ctx_scopes in let all_subscope_vars = subscope_sig.scope_sig_local_vars in let all_subscope_input_vars = @@ -968,7 +951,7 @@ let translate_scope_decl (sigma : 'm Scopelang.Ast.scope_decl) = let sigma_info = ScopeName.get_info sigma.scope_decl_name in let scope_sig = - ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters.scope_sigs + ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters in let scope_variables = scope_sig.scope_sig_local_vars in let ctx = { ctx with scope_name = Some scope_name } in @@ -1088,8 +1071,8 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = Scopelang.Dependency.get_defs_ordering defs_dependencies in let decl_ctx = prgm.program_ctx in - let sctx : 'm scope_sigs_ctx = - let process_scope_sig scope_name scope = + let scopes_parameters : 'm scope_sig_ctx ScopeName.Map.t = + let process_scope_sig decl_ctx scope_name scope = let scope_path = ScopeName.path scope_name in let scope_ref = if scope_path = [] then @@ -1100,13 +1083,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = (Mark.copy (ScopeName.get_info scope_name) scope_name) in let scope_info = - try - ScopeName.Map.find scope_name - (Program.module_ctx decl_ctx scope_path).ctx_scopes - with ScopeName.Map.Not_found _ -> - Message.raise_spanned_error - (Mark.get (ScopeName.get_info scope_name)) - "Could not find scope %a" ScopeName.format scope_name + ScopeName.Map.find scope_name decl_ctx.ctx_scopes in let scope_sig_in_fields = (* Output fields have already been generated and added to the program @@ -1154,69 +1131,45 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = scope_sig_in_fields; } in - let rec process_modules prg = - { - scope_sigs = - ScopeName.Map.mapi - (fun scope_name (scope_decl, _) -> - process_scope_sig scope_name scope_decl) - prg.Scopelang.Ast.program_scopes; - scope_sigs_modules = - ModuleName.Map.map process_modules prg.Scopelang.Ast.program_modules; - } + let process_scopes scopes = + ScopeName.Map.mapi + (fun scope_name (scope_decl, _) -> + process_scope_sig decl_ctx scope_name scope_decl) + scopes in - { - scope_sigs = - ScopeName.Map.mapi - (fun scope_name (scope_decl, _) -> - process_scope_sig scope_name scope_decl) - prgm.Scopelang.Ast.program_scopes; - scope_sigs_modules = - ModuleName.Map.map process_modules prgm.Scopelang.Ast.program_modules; - } + ModuleName.Map.fold (fun _ s -> + ScopeName.Map.disjoint_union + (process_scopes s)) + prgm.Scopelang.Ast.program_modules + (process_scopes prgm.Scopelang.Ast.program_scopes) in - let add_scope_in_structs scope_sigs structs = + let ctx_structs = ScopeName.Map.fold (fun _ scope_sig_ctx acc -> - let fields = - ScopeVar.Map.fold - (fun _ sivc acc -> - let pos = Mark.get (StructField.get_info sivc.scope_input_name) in - StructField.Map.add sivc.scope_input_name - (sivc.scope_input_typ, pos) - acc) - scope_sig_ctx.scope_sig_in_fields StructField.Map.empty - in - StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) - scope_sigs.scope_sigs structs + let fields = + ScopeVar.Map.fold + (fun _ sivc acc -> + let pos = Mark.get (StructField.get_info sivc.scope_input_name) in + StructField.Map.add sivc.scope_input_name + (sivc.scope_input_typ, pos) + acc) + scope_sig_ctx.scope_sig_in_fields StructField.Map.empty + in + StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) + scopes_parameters decl_ctx.ctx_structs in - let rec gather_module_in_structs acc sctx = - (* Expose all added in_structs from submodules at toplevel *) - ModuleName.Map.fold - (fun _ scope_sigs acc -> - add_scope_in_structs scope_sigs - (gather_module_in_structs acc scope_sigs.scope_sigs_modules)) - sctx acc + let decl_ctx = { decl_ctx with ctx_structs } in + let toplevel_vars = + TopdefName.Map.mapi + (fun name (_, ty) -> + Var.make (Mark.remove (TopdefName.get_info name)), Mark.remove ty) + prgm.Scopelang.Ast.program_topdefs in - let decl_ctx = - { - decl_ctx with - ctx_structs = - add_scope_in_structs sctx - (gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules); - } - in - let top_ctx = - let toplevel_vars = - TopdefName.Map.mapi - (fun name (_, ty) -> - Var.make (Mark.remove (TopdefName.get_info name)), Mark.remove ty) - prgm.Scopelang.Ast.program_topdefs - in + let ctx = { decl_ctx; scope_name = None; - scopes_parameters = sctx; + scopes_parameters; scope_vars = ScopeVar.Map.empty; subscope_vars = SubScopeName.Map.empty; toplevel_vars; @@ -1226,7 +1179,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = (* the resulting expression is the list of definitions of all the scopes, ending with the top-level scope. The decl_ctx is filled in left-to-right order, then the chained scopes aggregated from the right. *) - let rec translate_defs ctx = function + let rec translate_defs = function | [] -> Bindlib.box Nil | def :: next -> let dvar, def = @@ -1245,7 +1198,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = in let scope_var = match - (ScopeName.Map.find scope_name sctx.scope_sigs) + (ScopeName.Map.find scope_name scopes_parameters) .scope_sig_scope_ref with | Local_scope_ref v -> v @@ -1256,13 +1209,13 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = (fun body -> ScopeDef (scope_name, body)) scope_body ) in - let scope_next = translate_defs ctx next in + let scope_next = translate_defs next in let next_bind = Bindlib.bind_var dvar scope_next in Bindlib.box_apply2 (fun item next_bind -> Cons (item, next_bind)) def next_bind in - let items = translate_defs top_ctx defs_ordering in + let items = translate_defs defs_ordering in Expr.Box.assert_closed items; { code_items = Bindlib.unbox items; diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 768cae64..81202338 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -228,12 +228,16 @@ type scope = { scope_meta_assertions : meta_assertion list; } +type modul = { + module_scopes : scope ScopeName.Map.t; + module_topdefs : (expr option * typ) TopdefName.Map.t; +} + type program = { - program_module_name : ModuleName.t option; - program_scopes : scope ScopeName.Map.t; - program_topdefs : (expr option * typ) TopdefName.Map.t; + program_module_name : Ident.t Mark.pos option; program_ctx : decl_ctx; - program_modules : program ModuleName.Map.t; + program_modules : modul ModuleName.Map.t; + program_root : modul; program_lang : Cli.backend_lang; } @@ -299,8 +303,8 @@ let fold_exprs ~(f : 'a -> expr -> 'a) ~(init : 'a) (p : program) : 'a = scope.scope_assertions acc in acc) - p.program_scopes init + p.program_root.module_scopes init in TopdefName.Map.fold (fun _ (e, _) acc -> Option.fold ~none:acc ~some:(f acc) e) - p.program_topdefs acc + p.program_root.module_topdefs acc diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index f4f919bc..a8f24c07 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -93,6 +93,7 @@ type io = { type scope_def = { scope_def_rules : rule RuleName.Map.t; + (** empty outside of the root module *) scope_def_typ : typ; scope_def_parameters : (Uid.MarkedString.info * Shared_ast.typ) list Mark.pos option; @@ -108,16 +109,22 @@ type scope = { scope_uid : ScopeName.t; scope_defs : scope_def ScopeDef.Map.t; scope_assertions : assertion AssertionName.Map.t; + (** empty outside of the root module *) scope_options : catala_option Mark.pos list; scope_meta_assertions : meta_assertion list; } +type modul = { + module_scopes : scope ScopeName.Map.t; + module_topdefs : (expr option * typ) TopdefName.Map.t; + (** the expr is [None] outside of the root module *) +} + type program = { - program_module_name : ModuleName.t option; - program_scopes : scope ScopeName.Map.t; - program_topdefs : (expr option * typ) TopdefName.Map.t; + program_module_name : Ident.t Mark.pos option; program_ctx : decl_ctx; - program_modules : program ModuleName.Map.t; + program_modules : modul ModuleName.Map.t; (** Contains all submodules of the program, in a flattened structure *) + program_root : modul; program_lang : Cli.backend_lang; } diff --git a/compiler/desugared/disambiguate.ml b/compiler/desugared/disambiguate.ml index ea1abffe..b91c5934 100644 --- a/compiler/desugared/disambiguate.ml +++ b/compiler/desugared/disambiguate.ml @@ -64,53 +64,45 @@ let scope ctx env scope = let program prg = (* Caution: this environment building code is very similar to that in scopelang/ast.ml. Any edits should probably be reflected. *) - let base_typing_env prg = - let env = Typing.Env.empty prg.program_ctx in - let env = - TopdefName.Map.fold - (fun name (_e, ty) env -> Typing.Env.add_toplevel_var name ty env) - prg.program_topdefs env - in - let env = - ScopeName.Map.fold - (fun scope_name scope env -> - let vars = - ScopeDef.Map.fold - (fun var def vars -> + let env = Typing.Env.empty prg.program_ctx in + let env = + TopdefName.Map.fold + (fun name ty env -> Typing.Env.add_toplevel_var name ty env) + prg.program_ctx.ctx_topdefs env + in + let env = + ScopeName.Map.fold + (fun scope_name _info env -> + let modul = + List.fold_left + (fun _ m -> ModuleName.Map.find m prg.program_modules) + prg.program_root (ScopeName.path scope_name) + in + let scope = ScopeName.Map.find scope_name modul.module_scopes in + let vars = + ScopeDef.Map.fold + (fun var def vars -> match var with | Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars | SubScopeVar _ -> vars) - scope.scope_defs ScopeVar.Map.empty - in + scope.scope_defs ScopeVar.Map.empty + in (* at this stage, rule resolution and the corresponding encapsulation into default terms hasn't taken place, so input and output variables don't need different typing *) Typing.Env.add_scope scope_name ~vars ~in_vars:vars env) - prg.program_scopes env - in - env + prg.program_ctx.ctx_scopes env in - let rec build_typing_env prg = - ModuleName.Map.fold - (fun modname prg -> - Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules (base_typing_env prg) - in - let env = - ModuleName.Map.fold - (fun modname prg -> - Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules (base_typing_env prg) - in - let program_topdefs = + let module_topdefs = TopdefName.Map.map (function | Some e, ty -> Some (Expr.unbox (expr prg.program_ctx env (Expr.box e))), ty | None, ty -> None, ty) - prg.program_topdefs + prg.program_root.module_topdefs in - let program_scopes = - ScopeName.Map.map (scope prg.program_ctx env) prg.program_scopes + let module_scopes = + ScopeName.Map.map (scope prg.program_ctx env) + prg.program_root.module_scopes in - { prg with program_topdefs; program_scopes } + { prg with program_root = { module_topdefs; module_scopes } } diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index dda8c3d7..00cac405 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -123,7 +123,7 @@ let translate_unop (op : S.unop) pos : Ast.expr boxed = let raise_error_cons_not_found (ctxt : Name_resolution.context) (constructor : string Mark.pos) = - let constructors = Ident.Map.keys ctxt.constructor_idmap in + let constructors = Ident.Map.keys ctxt.local.constructor_idmap in let closest_constructors = Suggestions.suggestion_minimum_levenshtein_distance_association constructors (Mark.remove constructor) @@ -146,7 +146,7 @@ let rec disambiguate_constructor "The deep pattern matching syntactic sugar is not yet supported" in let possible_c_uids = - try Ident.Map.find (Mark.remove constructor) ctxt.constructor_idmap + try Ident.Map.find (Mark.remove constructor) ctxt.local.constructor_idmap with Ident.Map.Not_found _ -> raise_error_cons_not_found ctxt constructor in match path with @@ -168,17 +168,13 @@ let rec disambiguate_constructor with EnumName.Map.Not_found _ -> Message.raise_spanned_error pos "Enum %s does not contain case %s" (Mark.remove enum) (Mark.remove constructor)) - | modname :: path -> ( - let modname = ModuleName.of_string modname in - match ModuleName.Map.find_opt modname ctxt.modules with - | None -> - Message.raise_spanned_error (ModuleName.pos modname) - "Module \"%a\" not found" ModuleName.format modname - | Some ctxt -> - let constructor = - List.map (Mark.map (fun (_, c) -> path, c)) constructor0 - in - disambiguate_constructor ctxt constructor pos) + | mod_id :: path -> + let constructor = + List.map (Mark.map (fun (_, c) -> path, c)) constructor0 + in + disambiguate_constructor + (Name_resolution.get_module_ctx ctxt mod_id) + constructor pos let int100 = Runtime.integer_of_int 100 let rat100 = Runtime.decimal_of_integer int100 @@ -370,7 +366,7 @@ let rec translate_expr (* Note: allowing access to a global variable with the same name as a subscope is disputable, but I see no good reason to forbid it either *) | None -> ( - match Ident.Map.find_opt x ctxt.topdefs with + match Ident.Map.find_opt x ctxt.local.topdefs with | Some v -> Expr.elocation (ToplevelVar { name = v, Mark.get (TopdefName.get_info v) }) @@ -380,7 +376,7 @@ let rec translate_expr "for a local, scope-wide or global variable" (x, pos)))) | Ident (path, name) -> ( let ctxt = Name_resolution.module_ctx ctxt path in - match Ident.Map.find_opt (Mark.remove name) ctxt.topdefs with + match Ident.Map.find_opt (Mark.remove name) ctxt.local.topdefs with | Some v -> Expr.elocation (ToplevelVar { name = v, Mark.get (TopdefName.get_info v) }) @@ -415,13 +411,8 @@ let rec translate_expr let rec get_str ctxt = function | [] -> None | [c] -> Some (Name_resolution.get_struct ctxt c) - | modname :: path -> ( - let modname = ModuleName.of_string modname in - match ModuleName.Map.find_opt modname ctxt.modules with - | None -> - Message.raise_spanned_error (ModuleName.pos modname) - "Module \"%a\" not found" ModuleName.format modname - | Some ctxt -> get_str ctxt path) + | mod_id :: path -> + get_str (Name_resolution.get_module_ctx ctxt mod_id) path in Expr.edstructaccess ~e ~field:(Mark.remove x) ~name_opt:(get_str ctxt path) emark) @@ -478,7 +469,7 @@ let rec translate_expr | StructLit (((path, s_name), _), fields) -> let ctxt = Name_resolution.module_ctx ctxt path in let s_uid = - match Ident.Map.find_opt (Mark.remove s_name) ctxt.typedefs with + match Ident.Map.find_opt (Mark.remove s_name) ctxt.local.typedefs with | Some (Name_resolution.TStruct s_uid) -> s_uid | _ -> Message.raise_spanned_error (Mark.get s_name) @@ -490,7 +481,7 @@ let rec translate_expr let f_uid = try StructName.Map.find s_uid - (Ident.Map.find (Mark.remove f_name) ctxt.field_idmap) + (Ident.Map.find (Mark.remove f_name) ctxt.local.field_idmap) with StructName.Map.Not_found _ | Ident.Map.Not_found _ -> Message.raise_spanned_error (Mark.get f_name) "This identifier should refer to a field of struct %s" @@ -518,7 +509,7 @@ let rec translate_expr Expr.estruct ~name:s_uid ~fields:s_fields emark | EnumInject (((path, (constructor, pos_constructor)), _), payload) -> ( let get_possible_c_uids ctxt = - try Ident.Map.find constructor ctxt.Name_resolution.constructor_idmap + try Ident.Map.find constructor ctxt.Name_resolution.local.constructor_idmap with Ident.Map.Not_found _ -> raise_error_cons_not_found ctxt (constructor, pos_constructor) in @@ -1027,7 +1018,7 @@ let process_def (ctxt : Name_resolution.context) (prgm : Ast.program) (def : S.definition) : Ast.program = - let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in + let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_root.module_scopes in let scope_ctxt = ScopeName.Map.find scope_uid ctxt.scopes in let def_key = Name_resolution.get_def_key @@ -1091,10 +1082,13 @@ let process_def scope_defs = Ast.ScopeDef.Map.add def_key scope_def scope.scope_defs; } in + let module_scopes = + ScopeName.Map.add scope_uid scope_updated + prgm.program_root.module_scopes + in { prgm with - program_scopes = - ScopeName.Map.add scope_uid scope_updated prgm.program_scopes; + program_root = { prgm.program_root with module_scopes } } (** Translates a {!type: S.rule} from the surface language *) @@ -1114,7 +1108,7 @@ let process_assert (ctxt : Name_resolution.context) (prgm : Ast.program) (ass : S.assertion) : Ast.program = - let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in + let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_root.module_scopes in let ass = translate_expr (Some scope_uid) None ctxt Ident.Map.empty (match ass.S.assertion_condition with @@ -1146,9 +1140,11 @@ let process_assert scope.scope_assertions; } in + let module_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_root.module_scopes + in { prgm with - program_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_scopes; + program_root = { prgm.program_root with module_scopes } } (** Translates a surface definition, rule or assertion *) @@ -1167,7 +1163,7 @@ let process_scope_use_item | S.Assertion ass -> process_assert precond scope ctxt prgm ass | S.DateRounding (r, _) -> let scope_uid = scope in - let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in + let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_root.module_scopes in let r = match r with | S.Increasing -> Ast.Increasing @@ -1192,9 +1188,10 @@ let process_scope_use_item Mark.copy item (Ast.DateRounding r) :: scope.scope_options; } in + let module_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_root.module_scopes in { prgm with - program_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_scopes; + program_root = { prgm.program_root with module_scopes } } | _ -> prgm @@ -1254,7 +1251,7 @@ let process_scope_use let scope_uid = Name_resolution.get_scope ctxt use.scope_use_name in (* Make sure the scope exists *) let prgm = - match ScopeName.Map.find_opt scope_uid prgm.program_scopes with + match ScopeName.Map.find_opt scope_uid prgm.program_root.module_scopes with | Some _ -> prgm | None -> assert false (* should not happen *) @@ -1270,7 +1267,7 @@ let process_topdef (prgm : Ast.program) (def : S.top_def) : Ast.program = let id = - Ident.Map.find (Mark.remove def.S.topdef_name) ctxt.Name_resolution.topdefs + Ident.Map.find (Mark.remove def.S.topdef_name) ctxt.Name_resolution.local.topdefs in let translate_typ t = Name_resolution.process_type ctxt t in let translate_tbase (tbase, m) = translate_typ (Base tbase, m) in @@ -1300,7 +1297,7 @@ let process_topdef in Some (Expr.unbox_closed e) in - let program_topdefs = + let module_topdefs = TopdefName.Map.update id (fun def0 -> match def0, expr_opt with @@ -1318,9 +1315,9 @@ let process_topdef | Some _, Some _ -> err "Multiple definitions" | Some e, None -> Some (Some e, typ) | None, Some e -> Some (Some e, ty0))) - prgm.Ast.program_topdefs + prgm.Ast.program_root.module_topdefs in - { prgm with Ast.program_topdefs } + { prgm with program_root = { prgm.program_root with module_topdefs } } let attribute_to_io (attr : S.scope_decl_context_io) : Ast.io = { @@ -1337,13 +1334,13 @@ let attribute_to_io (attr : S.scope_decl_context_io) : Ast.io = let init_scope_defs (ctxt : Name_resolution.context) - (scope_idmap : Name_resolution.scope_var_or_subscope Ident.Map.t) : + (scope_idmap : scope_var_or_subscope Ident.Map.t) : Ast.scope_def Ast.ScopeDef.Map.t = (* Initializing the definitions of all scopes and subscope vars, with no rules yet inside *) let add_def _ v scope_def_map = match v with - | Name_resolution.ScopeVar v -> ( + | ScopeVar v -> ( let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in match v_sig.var_sig_states_list with | [] -> @@ -1389,19 +1386,20 @@ let init_scope_defs (scope_def_map, 0) states in scope_def) - | Name_resolution.SubScope (v0, subscope_uid) -> + | SubScope (v0, subscope_uid) -> let sub_scope_def = Name_resolution.get_scope_context ctxt subscope_uid in let ctxt = List.fold_left - (fun ctx m -> ModuleName.Map.find m ctx.Name_resolution.modules) + (fun ctx m -> + { ctxt with local = ModuleName.Map.find m ctx.Name_resolution.modules }) ctxt (ScopeName.path subscope_uid) in Ident.Map.fold (fun _ v scope_def_map -> match v with - | Name_resolution.SubScope _ -> scope_def_map - | Name_resolution.ScopeVar v -> + | SubScope _ -> scope_def_map + | ScopeVar v -> (* TODO: shouldn't we ignore internal variables too at this point ? *) let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in @@ -1424,91 +1422,110 @@ let init_scope_defs (** Main function of this module *) let translate_program (ctxt : Name_resolution.context) (surface : S.program) : Ast.program = - let top_ctx = ctxt in - let desugared = - let get_program_scopes ctxt = - ScopeName.Map.mapi - (fun s_uid s_context -> - let scope_vars = - Ident.Map.fold - (fun _ v acc -> - match v with - | Name_resolution.SubScope _ -> acc - | Name_resolution.ScopeVar v -> ( - let v_sig = - ScopeVar.Map.find v ctxt.Name_resolution.var_typs - in - match v_sig.Name_resolution.var_sig_states_list with - | [] -> ScopeVar.Map.add v Ast.WholeVar acc - | states -> ScopeVar.Map.add v (Ast.States states) acc)) - s_context.Name_resolution.var_idmap ScopeVar.Map.empty - in - let scope_sub_scopes = - Ident.Map.fold - (fun _ v acc -> - match v with - | Name_resolution.ScopeVar _ -> acc - | Name_resolution.SubScope (sub_var, sub_scope) -> - SubScopeName.Map.add sub_var sub_scope acc) - s_context.Name_resolution.var_idmap SubScopeName.Map.empty - in - { - Ast.scope_vars; - scope_sub_scopes; - scope_defs = init_scope_defs top_ctx s_context.var_idmap; - scope_assertions = Ast.AssertionName.Map.empty; - scope_meta_assertions = []; - scope_options = []; - scope_uid = s_uid; - }) - ctxt.Name_resolution.scopes + let get_scope s_uid = + let s_context = ScopeName.Map.find s_uid ctxt.scopes in + let scope_vars = + Ident.Map.fold + (fun _ v acc -> + match v with + | SubScope _ -> acc + | ScopeVar v -> ( + let v_sig = + ScopeVar.Map.find v ctxt.Name_resolution.var_typs + in + match v_sig.Name_resolution.var_sig_states_list with + | [] -> ScopeVar.Map.add v Ast.WholeVar acc + | states -> ScopeVar.Map.add v (Ast.States states) acc)) + s_context.Name_resolution.var_idmap ScopeVar.Map.empty in - let rec make_ctx ctxt = - let submodules = - ModuleName.Map.map make_ctx ctxt.Name_resolution.modules + let scope_sub_scopes = + Ident.Map.fold + (fun _ v acc -> + match v with + | ScopeVar _ -> acc + | SubScope (sub_var, sub_scope) -> + SubScopeName.Map.add sub_var sub_scope acc) + s_context.Name_resolution.var_idmap SubScopeName.Map.empty + in + { + Ast.scope_vars; + scope_sub_scopes; + scope_defs = init_scope_defs ctxt s_context.var_idmap; + scope_assertions = Ast.AssertionName.Map.empty; + scope_meta_assertions = []; + scope_options = []; + scope_uid = s_uid; + } + in + let get_scopes mctx = + Ident.Map.fold (fun _ tydef acc -> match tydef with + | Name_resolution.TScope (s_uid, _) -> + ScopeName.Map.add s_uid (get_scope s_uid) acc + | _ -> acc) + mctx.Name_resolution.typedefs ScopeName.Map.empty; + in + let program_modules = + ModuleName.Map.map (fun mctx -> + { Ast.module_scopes = get_scopes mctx; + Ast.module_topdefs = + Ident.Map.fold (fun _ name acc -> + TopdefName.Map.add name + (None, + TopdefName.Map.find name ctxt.Name_resolution.topdef_types) + acc; + ) + mctx.topdefs TopdefName.Map.empty + }) + ctxt.modules + in + let program_ctx = + let open Name_resolution in + let ctx_scopes mctx acc = + Ident.Map.fold (fun _ tydef acc -> + match tydef with + | TScope (s_uid, info) -> + ScopeName.Map.add s_uid info acc + | _ -> acc) + mctx.Name_resolution.typedefs acc + in + let ctx_modules = + let rec aux mctx = + Ident.Map.fold (fun _ m (M acc) -> + let sub = aux (ModuleName.Map.find m ctxt.modules) in + M (ModuleName.Map.add m sub acc)) + mctx.used_modules (M ModuleName.Map.empty) in - { - Ast.program_lang = surface.program_lang; - Ast.program_module_name = - Option.map ModuleName.of_string - surface.Surface.Ast.program_module_name; - Ast.program_ctx = - { - (* After name resolution, type definitions (structs and enums) are - exposed at toplevel for easier lookup *) - ctx_structs = - ModuleName.Map.fold - (fun _ prg acc -> - StructName.Map.union - (fun _ _ _ -> assert false) - acc prg.Ast.program_ctx.ctx_structs) - submodules ctxt.Name_resolution.structs; - ctx_enums = - ModuleName.Map.fold - (fun _ prg acc -> - EnumName.Map.union - (fun _ _ _ -> assert false) - acc prg.Ast.program_ctx.ctx_enums) - submodules ctxt.Name_resolution.enums; - ctx_scopes = - Ident.Map.fold - (fun _ def acc -> - match def with - | Name_resolution.TScope (scope, scope_info) -> - ScopeName.Map.add scope scope_info acc - | _ -> acc) - ctxt.Name_resolution.typedefs ScopeName.Map.empty; - ctx_struct_fields = ctxt.Name_resolution.field_idmap; - ctx_topdefs = ctxt.Name_resolution.topdef_types; - ctx_modules = - ModuleName.Map.map (fun s -> s.Ast.program_ctx) submodules; - }; - Ast.program_topdefs = TopdefName.Map.empty; - Ast.program_scopes = get_program_scopes ctxt; - Ast.program_modules = submodules; - } + aux ctxt.local in - make_ctx ctxt + { + ctx_structs = ctxt.structs; + ctx_enums = ctxt.enums; + ctx_scopes = + ModuleName.Map.fold (fun _ -> ctx_scopes) + ctxt.modules + (ctx_scopes ctxt.local ScopeName.Map.empty); + ctx_topdefs = ctxt.topdef_types; + ctx_struct_fields = ctxt.local.field_idmap; + ctx_enum_constrs = ctxt.local.constructor_idmap; + ctx_scope_index = + Ident.Map.filter_map (fun _ -> function + | Name_resolution.TScope (s, _) -> Some s + | _ -> None) + ctxt.local.typedefs; + ctx_modules; + } + in + let desugared = + { + Ast.program_lang = surface.program_lang; + Ast.program_module_name = surface.Surface.Ast.program_module_name; + Ast.program_modules; + Ast.program_ctx; + Ast.program_root = { + Ast.module_scopes = get_scopes ctxt.Name_resolution.local; + Ast.module_topdefs = TopdefName.Map.empty; + }; + } in let process_code_block ctxt prgm block = List.fold_left @@ -1527,29 +1544,6 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : (fun prgm child -> process_structure prgm child) prgm children | S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block - | S.ModuleDef ((name, pos) as mname) -> - let file = Filename.basename (Pos.get_file pos) in - if not File.(equal name (Filename.remove_extension file)) then - Message.raise_spanned_error pos - "Module declared as %a, which does not match the file name %a" - ModuleName.format - (ModuleName.of_string mname) - File.format file - else prgm - | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm - in - let desugared = - List.fold_left - (fun acc (id, intf) -> - let id = ModuleName.of_string id in - let modul = ModuleName.Map.find id acc.Ast.program_modules in - let modul = - process_code_block (ModuleName.Map.find id ctxt.modules) modul intf - in - { - acc with - program_modules = ModuleName.Map.add id modul acc.program_modules; - }) - desugared surface.S.program_modules + | S.ModuleDef _ | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm in List.fold_left process_structure desugared surface.S.program_items diff --git a/compiler/desugared/linting.ml b/compiler/desugared/linting.ml index 8cfd054b..264fe784 100644 --- a/compiler/desugared/linting.ml +++ b/compiler/desugared/linting.ml @@ -39,7 +39,7 @@ let detect_empty_definitions (p : program) : unit = defined; did you forget something?" ScopeName.format scope_name Ast.ScopeDef.format scope_def_key) scope.scope_defs) - p.program_scopes + p.program_root.module_scopes (* To detect rules that have the same justification and conclusion, we create a set data structure with an appropriate comparison function *) @@ -97,7 +97,7 @@ let detect_identical_rules (p : program) : unit = else "definitions")) rules_seen) scope.scope_defs) - p.program_scopes + p.program_root.module_scopes let detect_unused_struct_fields (p : program) : unit = (* TODO: this analysis should be finer grained: a false negative is if the @@ -111,14 +111,9 @@ let detect_unused_struct_fields (p : program) : unit = ~f:(fun struct_fields_used e -> let rec structs_fields_used_expr e struct_fields_used = match Mark.remove e with - | EDStructAccess { name_opt = Some name; e = e_struct; field } -> - let ctx = - Program.module_ctx p.program_ctx (StructName.path name) - in - let field = - StructName.Map.find name - (Ident.Map.find field ctx.ctx_struct_fields) - in + | EDStructAccess _ -> assert false + (* linting must be performed after disambiguation *) + | EStructAccess { e = e_struct; field; _ } -> StructField.Set.add field (structs_fields_used_expr e_struct struct_fields_used) | EStruct { name = _; fields } -> @@ -284,7 +279,7 @@ let detect_dead_code (p : program) : unit = emit_unused_warning ()) states) scope.scope_vars) - p.program_scopes + p.program_root.module_scopes let lint_program (p : program) : unit = detect_empty_definitions p; diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index 5381b0e6..7cea74cb 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -30,10 +30,6 @@ type scope_def_context = { label_idmap : LabelName.t Ident.Map.t; } -type scope_var_or_subscope = - | ScopeVar of ScopeVar.t - | SubScope of SubScopeName.t * ScopeName.t - type scope_context = { var_idmap : scope_var_or_subscope Ident.Map.t; (** All variables, including scope variables and subscopes *) @@ -67,7 +63,7 @@ type typedef = | TEnum of EnumName.t | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) -type context = { +type module_context = { path : Uid.Path.t; typedefs : typedef Ident.Map.t; (** Gathers the names of the scopes, structs and enums *) @@ -77,17 +73,24 @@ type context = { constructor_idmap : EnumConstructor.t EnumName.Map.t Ident.Map.t; (** The names of the enum constructors. Constructor names can be shared between different enums *) - scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) + used_modules : ModuleName.t Ident.Map.t; +} +(** Context for name resolution, valid within a given module *) + +type context = { + scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) topdef_types : typ TopdefName.Map.t; structs : struct_context StructName.Map.t; (** For each struct, its context *) enums : enum_context EnumName.Map.t; (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) - modules : context ModuleName.Map.t; + modules : module_context ModuleName.Map.t; + local : module_context; + (** Module being currently analysed (at the end: the root module) *) } -(** Main context used throughout {!module: Surface.Desugaring} *) +(** Global context used throughout {!module: Surface.Desugaring} *) (** {1 Helpers} *) @@ -114,16 +117,6 @@ let get_var_io (ctxt : context) (uid : ScopeVar.t) : (ScopeVar.Map.find uid ctxt.var_typs).var_sig_io let get_scope_context (ctxt : context) (scope : ScopeName.t) : scope_context = - let rec remove_common_prefix curpath scpath = - match curpath, scpath with - | m1 :: cp, m2 :: sp when ModuleName.equal m1 m2 -> - remove_common_prefix cp sp - | _ -> scpath - in - let path = remove_common_prefix ctxt.path (ScopeName.path scope) in - let ctxt = - List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.modules) ctxt path - in ScopeName.Map.find scope ctxt.scopes (** Get the variable uid inside the scope given in argument *) @@ -196,7 +189,7 @@ let is_def_cond (ctxt : context) (def : Ast.ScopeDef.t) : bool = is_var_cond ctxt x let get_enum ctxt id = - match Ident.Map.find (Mark.remove id) ctxt.typedefs with + match Ident.Map.find (Mark.remove id) ctxt.local.typedefs with | TEnum id -> id | TStruct sid -> Message.raise_multispanned_error @@ -217,7 +210,7 @@ let get_enum ctxt id = (Mark.remove id) let get_struct ctxt id = - match Ident.Map.find (Mark.remove id) ctxt.typedefs with + match Ident.Map.find (Mark.remove id) ctxt.local.typedefs with | TStruct id | TScope (_, { out_struct_name = id; _ }) -> id | TEnum eid -> Message.raise_multispanned_error @@ -231,7 +224,7 @@ let get_struct ctxt id = (Mark.remove id) let get_scope ctxt id = - match Ident.Map.find (Mark.remove id) ctxt.typedefs with + match Ident.Map.find (Mark.remove id) ctxt.local.typedefs with | TScope (id, _) -> id | TEnum eid -> Message.raise_multispanned_error @@ -251,16 +244,21 @@ let get_scope ctxt id = Message.raise_spanned_error (Mark.get id) "No scope named %s found" (Mark.remove id) -let rec module_ctx ctxt path = - match path with +let get_modname ctxt (id, pos) = + match Ident.Map.find_opt id ctxt.local.used_modules with + | None -> + Message.raise_spanned_error pos "Module \"@{%s@}\" not found" id + | Some modname -> modname + +let get_module_ctx ctxt id = + let modname = get_modname ctxt id in + { ctxt with local = ModuleName.Map.find modname ctxt.modules } + +let rec module_ctx ctxt path0 = + match path0 with | [] -> ctxt - | modname :: path -> ( - let modname = ModuleName.of_string modname in - match ModuleName.Map.find_opt modname ctxt.modules with - | None -> - Message.raise_spanned_error (ModuleName.pos modname) - "Module \"%a\" not found" ModuleName.format modname - | Some ctxt -> module_ctx ctxt path) + | mod_id :: path -> + module_ctx (get_module_ctx ctxt mod_id) path (** {1 Declarations pass} *) @@ -328,7 +326,7 @@ let rec process_base_typ | Surface.Ast.Boolean -> TLit TBool, typ_pos | Surface.Ast.Text -> raise_unsupported_feature "text type" typ_pos | Surface.Ast.Named ([], (ident, _pos)) -> ( - match Ident.Map.find_opt ident ctxt.typedefs with + match Ident.Map.find_opt ident ctxt.local.typedefs with | Some (TStruct s_uid) -> TStruct s_uid, typ_pos | Some (TEnum e_uid) -> TEnum e_uid, typ_pos | Some (TScope (_, scope_str)) -> @@ -338,15 +336,14 @@ let rec process_base_typ "Unknown type @{\"%s\"@}, not a struct or enum previously \ declared" ident) - | Surface.Ast.Named (modul :: path, id) -> ( - let modul = ModuleName.of_string modul in - match ModuleName.Map.find_opt modul ctxt.modules with + | Surface.Ast.Named ((modul, mpos) :: path, id) -> ( + match Ident.Map.find_opt modul ctxt.local.used_modules with | None -> - Message.raise_spanned_error (ModuleName.pos modul) - "This refers to module %a, which was not found" ModuleName.format - modul - | Some mod_ctxt -> - process_base_typ mod_ctxt + Message.raise_spanned_error mpos + "This refers to module @{%s@}, which was not found" modul + | Some mname -> + let mod_ctxt = ModuleName.Map.find mname ctxt.modules in + process_base_typ { ctxt with local = mod_ctxt } Surface.Ast.(Data (Primitive (Named (path, id))), typ_pos))) (** Process a type (function or not) *) @@ -449,9 +446,9 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : List.fold_left (fun ctxt (fdecl, _) -> let f_uid = StructField.fresh fdecl.Surface.Ast.struct_decl_field_name in - let ctxt = + let local = { - ctxt with + ctxt.local with field_idmap = Ident.Map.update (Mark.remove fdecl.Surface.Ast.struct_decl_field_name) @@ -459,26 +456,26 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : match uids with | None -> Some (StructName.Map.singleton s_uid f_uid) | Some uids -> Some (StructName.Map.add s_uid f_uid uids)) - ctxt.field_idmap; + ctxt.local.field_idmap; } in - { - ctxt with - structs = - StructName.Map.update s_uid - (fun fields -> - match fields with - | None -> - Some - (StructField.Map.singleton f_uid - (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) - | Some fields -> - Some - (StructField.Map.add f_uid - (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) - fields)) - ctxt.structs; - }) + let ctxt = { ctxt with local } in + let structs = + StructName.Map.update s_uid + (fun fields -> + match fields with + | None -> + Some + (StructField.Map.singleton f_uid + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) + | Some fields -> + Some + (StructField.Map.add f_uid + (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) + fields)) + ctxt.structs + in + { ctxt with structs }) ctxt sdecl.struct_decl_fields (** Process an enum declaration *) @@ -494,9 +491,9 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context List.fold_left (fun ctxt (cdecl, cdecl_pos) -> let c_uid = EnumConstructor.fresh cdecl.Surface.Ast.enum_decl_case_name in - let ctxt = + let local = { - ctxt with + ctxt.local with constructor_idmap = Ident.Map.update (Mark.remove cdecl.Surface.Ast.enum_decl_case_name) @@ -504,29 +501,29 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context match uids with | None -> Some (EnumName.Map.singleton e_uid c_uid) | Some uids -> Some (EnumName.Map.add e_uid c_uid uids)) - ctxt.constructor_idmap; + ctxt.local.constructor_idmap; } in - { - ctxt with - enums = - EnumName.Map.update e_uid - (fun cases -> - let typ = - match cdecl.Surface.Ast.enum_decl_case_typ with - | None -> TLit TUnit, cdecl_pos - | Some typ -> process_type ctxt typ - in - match cases with - | None -> Some (EnumConstructor.Map.singleton c_uid typ) - | Some fields -> Some (EnumConstructor.Map.add c_uid typ fields)) - ctxt.enums; - }) + let ctxt = { ctxt with local } in + let enums = + EnumName.Map.update e_uid + (fun cases -> + let typ = + match cdecl.Surface.Ast.enum_decl_case_typ with + | None -> TLit TUnit, cdecl_pos + | Some typ -> process_type ctxt typ + in + match cases with + | None -> Some (EnumConstructor.Map.singleton c_uid typ) + | Some fields -> Some (EnumConstructor.Map.add c_uid typ fields)) + ctxt.enums + in + { ctxt with enums }) ctxt edecl.enum_decl_cases let process_topdef ctxt def = let uid = - Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.topdefs + Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.local.topdefs in { ctxt with @@ -605,7 +602,7 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : | ScopeVar v -> ( try let field = - StructName.Map.find str (Ident.Map.find id ctxt.field_idmap) + StructName.Map.find str (Ident.Map.find id ctxt.local.field_idmap) in ScopeVar.Map.add v field svmap with StructName.Map.Not_found _ | Ident.Map.Not_found _ -> svmap)) @@ -620,9 +617,9 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : (TScope (scope, { in_struct_name; out_struct_name; out_struct_fields })) | _ -> assert false) - ctxt.typedefs + ctxt.local.typedefs in - { ctxt with typedefs } + { ctxt with local = { ctxt.local with typedefs } } let typedef_info = function | TStruct t -> StructName.get_info t @@ -648,59 +645,61 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : Option.iter (fun use -> raise_already_defined_error (typedef_info use) name pos "scope") - (Ident.Map.find_opt name ctxt.typedefs); - let scope_uid = ScopeName.fresh ctxt.path (name, pos) in - let in_struct_name = StructName.fresh ctxt.path (name ^ "_in", pos) in - let out_struct_name = StructName.fresh ctxt.path (name, pos) in + (Ident.Map.find_opt name ctxt.local.typedefs); + let scope_uid = ScopeName.fresh ctxt.local.path (name, pos) in + let in_struct_name = StructName.fresh ctxt.local.path (name ^ "_in", pos) in + let out_struct_name = StructName.fresh ctxt.local.path (name, pos) in + let typedefs = + Ident.Map.add name + (TScope + ( scope_uid, + { + in_struct_name; + out_struct_name; + out_struct_fields = ScopeVar.Map.empty; + } )) + ctxt.local.typedefs + in + let scopes = + ScopeName.Map.add scope_uid + { + var_idmap = Ident.Map.empty; + scope_defs_contexts = Ast.ScopeDef.Map.empty; + sub_scopes = ScopeName.Set.empty; + } + ctxt.scopes + in { ctxt with - typedefs = - Ident.Map.add name - (TScope - ( scope_uid, - { - in_struct_name; - out_struct_name; - out_struct_fields = ScopeVar.Map.empty; - } )) - ctxt.typedefs; - scopes = - ScopeName.Map.add scope_uid - { - var_idmap = Ident.Map.empty; - scope_defs_contexts = Ast.ScopeDef.Map.empty; - sub_scopes = ScopeName.Set.empty; - } - ctxt.scopes; + local = { ctxt.local with typedefs }; + scopes; } | StructDecl sdecl -> let name, pos = sdecl.struct_decl_name in Option.iter (fun use -> raise_already_defined_error (typedef_info use) name pos "struct") - (Ident.Map.find_opt name ctxt.typedefs); - let s_uid = StructName.fresh ctxt.path sdecl.struct_decl_name in - { - ctxt with - typedefs = - Ident.Map.add - (Mark.remove sdecl.struct_decl_name) - (TStruct s_uid) ctxt.typedefs; - } + (Ident.Map.find_opt name ctxt.local.typedefs); + let s_uid = StructName.fresh ctxt.local.path sdecl.struct_decl_name in + let typedefs = + Ident.Map.add + (Mark.remove sdecl.struct_decl_name) + (TStruct s_uid) ctxt.local.typedefs; + in + { ctxt with local = { ctxt.local with typedefs} } | EnumDecl edecl -> let name, pos = edecl.enum_decl_name in Option.iter (fun use -> raise_already_defined_error (typedef_info use) name pos "enum") - (Ident.Map.find_opt name ctxt.typedefs); - let e_uid = EnumName.fresh ctxt.path edecl.enum_decl_name in - { - ctxt with - typedefs = - Ident.Map.add - (Mark.remove edecl.enum_decl_name) - (TEnum e_uid) ctxt.typedefs; - } + (Ident.Map.find_opt name ctxt.local.typedefs); + let e_uid = EnumName.fresh ctxt.local.path edecl.enum_decl_name in + let typedefs = + Ident.Map.add + (Mark.remove edecl.enum_decl_name) + (TEnum e_uid) ctxt.local.typedefs + in + { ctxt with local = { ctxt.local with typedefs} } | ScopeUse _ -> ctxt | Topdef def -> let name, pos = def.topdef_name in @@ -708,9 +707,10 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (fun use -> raise_already_defined_error (TopdefName.get_info use) name pos "toplevel definition") - (Ident.Map.find_opt name ctxt.topdefs); - let uid = TopdefName.fresh ctxt.path def.topdef_name in - { ctxt with topdefs = Ident.Map.add name uid ctxt.topdefs } + (Ident.Map.find_opt name ctxt.local.topdefs); + let uid = TopdefName.fresh ctxt.local.path def.topdef_name in + let topdefs = Ident.Map.add name uid ctxt.local.topdefs in + { ctxt with local = { ctxt.local with topdefs } } (** Process a code item that is a declaration *) let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : @@ -918,7 +918,7 @@ let process_scope_use (ctxt : context) (suse : Surface.Ast.scope_use) : context match Ident.Map.find_opt (Mark.remove suse.Surface.Ast.scope_use_name) - ctxt.typedefs + ctxt.local.typedefs with | Some (TScope (sn, _)) -> sn | _ -> @@ -940,83 +940,90 @@ let process_use_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (** {1 API} *) -let empty_ctxt = - { - path = []; - typedefs = Ident.Map.empty; - scopes = ScopeName.Map.empty; - topdefs = Ident.Map.empty; - topdef_types = TopdefName.Map.empty; - var_typs = ScopeVar.Map.empty; - structs = StructName.Map.empty; - field_idmap = Ident.Map.empty; - enums = EnumName.Map.empty; - constructor_idmap = Ident.Map.empty; - modules = ModuleName.Map.empty; - } +let empty_module_ctxt = { + path = []; + typedefs = Ident.Map.empty; + field_idmap = Ident.Map.empty; + constructor_idmap = Ident.Map.empty; + topdefs = Ident.Map.empty; + used_modules = Ident.Map.empty; +} -let import_module modules (name, intf) = - let mname = ModuleName.of_string name in - let ctxt = { empty_ctxt with modules; path = [mname] } in - let ctxt = List.fold_left process_name_item ctxt intf in - let ctxt = List.fold_left process_decl_item ctxt intf in - let ctxt = { ctxt with modules = empty_ctxt.modules } in - (* No submodules at the moment, a module may use the ones loaded before it, - but doesn't reexport them *) - ModuleName.Map.add mname ctxt modules +let empty_ctxt = { + scopes = ScopeName.Map.empty; + topdef_types = TopdefName.Map.empty; + var_typs = ScopeVar.Map.empty; + structs = StructName.Map.empty; + enums = EnumName.Map.empty; + modules = ModuleName.Map.empty; + local = empty_module_ctxt; +} (** Derive the context from metadata, in one pass over the declarations *) -let form_context (prgm : Surface.Ast.program) : context = - let modules = - List.fold_left import_module ModuleName.Map.empty prgm.program_modules - in - let ctxt = { empty_ctxt with modules } in - let rec gather_var_sigs acc modules = - (* Scope vars from imported modules need to be accessible directly for - definitions through submodules *) - ModuleName.Map.fold - (fun _modname mctx acc -> - let acc = gather_var_sigs acc mctx.modules in - ScopeVar.Map.union (fun _ _ -> assert false) acc mctx.var_typs) - modules acc - in - let ctxt = - { ctxt with var_typs = gather_var_sigs ScopeVar.Map.empty ctxt.modules } +let form_context (surface, mod_uses) surface_modules : context = + let rec process_modules ctxt mod_uses = + (* Recursing on [mod_uses] rather than folding on [modules] ensures a topological traversal. *) + Ident.Map.fold (fun _alias m ctxt -> + match ModuleName.Map.find_opt m ctxt.modules with + | Some _ -> ctxt + | None -> + let intf, mod_uses = ModuleName.Map.find m surface_modules in + let ctxt = process_modules ctxt mod_uses in + let ctxt = { ctxt with + local = { ctxt.local with used_modules = mod_uses; + path = [m] } } in + let ctxt = List.fold_left process_name_item ctxt intf.Surface.Ast.intf_code in + let ctxt = List.fold_left process_decl_item ctxt intf.Surface.Ast.intf_code in + { ctxt with + modules = ModuleName.Map.add m ctxt.local ctxt.modules; + local = empty_module_ctxt } + ) + mod_uses ctxt in + let ctxt = process_modules empty_ctxt mod_uses in + let ctxt = { ctxt with local = { empty_module_ctxt with used_modules = mod_uses } } in let ctxt = List.fold_left (process_law_structure process_name_item) - ctxt prgm.program_items + ctxt surface.Surface.Ast.program_items in let ctxt = List.fold_left (process_law_structure process_decl_item) - ctxt prgm.program_items + ctxt surface.Surface.Ast.program_items in let ctxt = List.fold_left (process_law_structure process_use_item) - ctxt prgm.program_items + ctxt surface.Surface.Ast.program_items in - let rec gather_all_constrs ctxt = - (* Gather struct fields and enum constrs from modules: this helps with - disambiguation *) - let modules, constructor_idmap, field_idmap = - ModuleName.Map.fold - (fun m ctx (mmap, constrs, fields) -> - let ctx = gather_all_constrs ctx in - ( ModuleName.Map.add m ctx mmap, - Ident.Map.union - (fun _ enu1 enu2 -> - Some (EnumName.Map.union (fun _ _ -> assert false) enu1 enu2)) - constrs ctx.constructor_idmap, - Ident.Map.union - (fun _ str1 str2 -> - Some (StructName.Map.union (fun _ _ -> assert false) str1 str2)) - fields ctx.field_idmap )) - ctxt.modules - (ModuleName.Map.empty, ctxt.constructor_idmap, ctxt.field_idmap) - in - { ctxt with modules; constructor_idmap; field_idmap } + (* Gather struct fields and enum constrs from direct modules: this helps with + disambiguation. This is only done towards the root context, because submodules are only interfaces which don't need disambiguation ; and transitive dependencies shouldn't be visible here. *) + let sub_constructor_idmap, sub_field_idmap = + Ident.Map.fold (fun _ m (cmap, fmap) -> + let lctx = ModuleName.Map.find m ctxt.modules in + let cmap = + Ident.Map.union + (fun _ enu1 enu2 -> Some (EnumName.Map.disjoint_union enu1 enu2)) + cmap lctx.constructor_idmap + in + let fmap = + Ident.Map.union + (fun _ str1 str2 -> Some (StructName.Map.disjoint_union str1 str2)) + fmap lctx.field_idmap + in + cmap, fmap) + mod_uses (Ident.Map.empty, Ident.Map.empty) in - gather_all_constrs ctxt + { ctxt with + local = + { ctxt.local with + (* In the root context, don't disambiguate on submodules structs/enums when there is a conflict *) + constructor_idmap = + Ident.Map.union (fun _ base _ -> Some base) + ctxt.local.constructor_idmap sub_constructor_idmap; + field_idmap = + Ident.Map.union (fun _ base _ -> Some base) + ctxt.local.field_idmap sub_field_idmap; + } + } diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index 86cc6c21..ad4ad964 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -30,10 +30,6 @@ type scope_def_context = { label_idmap : LabelName.t Ident.Map.t; } -type scope_var_or_subscope = - | ScopeVar of ScopeVar.t - | SubScope of SubScopeName.t * ScopeName.t - type scope_context = { var_idmap : scope_var_or_subscope Ident.Map.t; (** All variables, including scope variables and subscopes *) @@ -67,19 +63,24 @@ type typedef = | TEnum of EnumName.t | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) -type context = { - path : ModuleName.t list; - (** The current path being processed. Used for generating the Uids. *) +type module_context = { + path : Uid.Path.t; + (** The current path being processed. Used for generating the Uids. *) typedefs : typedef Ident.Map.t; - (** Gathers the names of the scopes, structs and enums *) + (** Gathers the names of the scopes, structs and enums *) field_idmap : StructField.t StructName.Map.t Ident.Map.t; - (** The names of the struct fields. Names of fields can be shared between - different structs *) + (** The names of the struct fields. Names of fields can be shared between + different structs. Note that fields from submodules are included here for the root module, because disambiguating there is helpful. *) constructor_idmap : EnumConstructor.t EnumName.Map.t Ident.Map.t; - (** The names of the enum constructors. Constructor names can be shared - between different enums *) - scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) + (** The names of the enum constructors. Constructor names can be shared + between different enums. Note that constructors from its submodules are included here for the root module, because disambiguating there is helpful. *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) + used_modules : ModuleName.t Ident.Map.t; (** Module aliases and the modules they point to *) +} +(** Context for name resolution, valid within a given module *) + +type context = { + scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) topdef_types : typ TopdefName.Map.t; (** Types associated with the global definitions *) structs : struct_context StructName.Map.t; @@ -87,9 +88,12 @@ type context = { enums : enum_context EnumName.Map.t; (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) - modules : context ModuleName.Map.t; + modules : module_context ModuleName.Map.t; + (** The map to the interfaces of all modules (transitively) used by the program. References are made through [local.used_modules] *) + local : module_context; + (** Local context of the root module corresponding to the program being analysed *) } -(** Main context used throughout {!module: Desugared.From_surface} *) +(** Global context used throughout {!module: Surface.Desugaring} *) (** {1 Helpers} *) @@ -101,6 +105,12 @@ val raise_unknown_identifier : string -> Ident.t Mark.pos -> 'a (** Function to call whenever an identifier used somewhere has not been declared in the program previously *) +val get_modname : context -> Ident.t Mark.pos -> ModuleName.t +(** Emits a user error if the module name is not found *) + +val get_module_ctx : context -> Ident.t Mark.pos -> context +(** Emits a user error if the module name is not found *) + val get_var_typ : context -> ScopeVar.t -> typ (** Gets the type associated to an uid *) @@ -166,5 +176,8 @@ val process_type : context -> Surface.Ast.typ -> typ (** {1 API} *) -val form_context : Surface.Ast.program -> context +val form_context : + Surface.Ast.program * ModuleName.t Ident.Map.t + -> (Surface.Ast.interface * ModuleName.t Ident.Map.t) ModuleName.Map.t + -> context (** Derive the context from metadata, in one pass over the declarations *) diff --git a/compiler/driver.ml b/compiler/driver.ml index c818db8d..374bb6f6 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -29,66 +29,86 @@ let modname_of_file f = let load_module_interfaces options includes program = (* Recurse into program modules, looking up files in [using] and loading them *) + if program.Surface.Ast.program_used_modules <> [] then + Message.emit_debug "Loading module interfaces..."; let includes = includes |> List.map (fun d -> File.Tree.build (options.Cli.path_rewrite d)) |> List.fold_left File.Tree.union File.Tree.empty in let err_req_pos chain = - List.map (fun m -> Some "Module required from", ModuleName.pos m) chain + List.map (fun mpos -> Some "Module required from", mpos) chain in - let find_module req_chain m = - let fname_base = ModuleName.to_string m in - let required_from_file = Pos.get_file (ModuleName.pos m) in + let find_module req_chain (mname, mpos) = + let required_from_file = Pos.get_file mpos in let includes = File.Tree.union includes (File.Tree.build (File.dirname required_from_file)) in match List.filter_map - (fun (ext, _) -> File.Tree.lookup includes (fname_base ^ ext)) + (fun (ext, _) -> File.Tree.lookup includes (mname ^ ext)) extensions with | [] -> Message.raise_multispanned_error - (err_req_pos (m :: req_chain)) - "Required module not found: %a" ModuleName.format m + (err_req_pos (mpos :: req_chain)) + "Required module not found: @{%s@}" mname | [f] -> f | ms -> Message.raise_multispanned_error - (err_req_pos (m :: req_chain)) - "Required module %a matches multiple files: %a" ModuleName.format m + (err_req_pos (mpos :: req_chain)) + "Required module @{%s@} matches multiple files:@;<1 2>%a" mname (Format.pp_print_list ~pp_sep:Format.pp_print_space File.format) ms in - let load_file f = - let (mname, intf), using = - Surface.Parser_driver.load_interface (Cli.FileName f) - in - (ModuleName.of_string mname, intf), using + (* modulename * program * (id -> modulename) *) + let rec aux req_chain seen uses = + List.fold_left (fun (seen, use_map) use -> + let f = find_module req_chain use.Surface.Ast.mod_use_name in + match File.Map.find_opt f seen with + | Some (Some (modname, _, _)) -> + seen, + Ident.Map.add + (Mark.remove use.Surface.Ast.mod_use_alias) modname use_map + | Some None -> + Message.raise_multispanned_error + (err_req_pos (Mark.get use.Surface.Ast.mod_use_name :: req_chain)) + "Circular module dependency" + | None -> + let intf = Surface.Parser_driver.load_interface (Cli.FileName f) in + let modname = ModuleName.fresh use.Surface.Ast.mod_use_name in + let seen = File.Map.add f None seen in + let seen, sub_use_map = + aux + (Mark.get use.Surface.Ast.mod_use_name :: req_chain) + seen + intf.Surface.Ast.intf_submodules + in + File.Map.add f (Some (modname, intf, sub_use_map)) seen, + Ident.Map.add + (Mark.remove use.Surface.Ast.mod_use_alias) modname use_map) + (seen, Ident.Map.empty) uses in - let rec aux req_chain acc modules = - List.fold_left - (fun acc mname -> - let m = ModuleName.of_string mname in - if List.exists (fun (m1, _) -> ModuleName.equal m m1) acc then acc - else - let f = find_module req_chain m in - let (m', intf), using = load_file f in - if not (ModuleName.equal m m') then - Message.raise_multispanned_error - ((Some "Module name declaration", ModuleName.pos m') - :: err_req_pos (m :: req_chain)) - "Mismatching module name declaration:"; - let acc = (m', intf) :: acc in - aux (m :: req_chain) acc using) - acc modules + let seen = + match program.Surface.Ast.program_module_name with + | Some m -> + let file = Pos.get_file (Mark.get m) in + File.Map.singleton file None + | None -> File.Map.empty in - let program_modules = - aux [] [] (List.map fst program.Surface.Ast.program_modules) - |> List.map (fun (m, i) -> (m : ModuleName.t :> string Mark.pos), i) + let file_module_map, root_uses = + aux [] seen program.Surface.Ast.program_used_modules in - { program with Surface.Ast.program_modules } + let modules = + File.Map.fold + (fun _ info acc -> match info with + | None -> acc + | Some (mname, intf, use_map) -> + ModuleName.Map.add mname (intf, use_map) acc) + file_module_map ModuleName.Map.empty + in + root_uses, modules module Passes = struct (* Each pass takes only its cli options, then calls upon its dependent passes @@ -98,23 +118,20 @@ module Passes = struct Message.emit_debug "@{=@} @{%s@} @{=@}" (String.uppercase_ascii s) - let surface options ~includes : Surface.Ast.program = + let surface options : Surface.Ast.program = debug_pass_name "surface"; let prg = Surface.Parser_driver.parse_top_level_file options.Cli.input_src in - let prg = Surface.Fill_positions.fill_pos_with_legislative_info prg in - load_module_interfaces options includes prg + Surface.Fill_positions.fill_pos_with_legislative_info prg let desugared options ~includes : Desugared.Ast.program * Desugared.Name_resolution.context = - let prg = surface options ~includes in + let prg = surface options in + let mod_uses, modules = load_module_interfaces options includes prg in debug_pass_name "desugared"; Message.emit_debug "Name resolution..."; - let ctx = Desugared.Name_resolution.form_context prg in - (* let scope_uid = get_scope_uid options backend ctx in - * (\* This uid is a Desugared identifier *\) - * let variable_uid = get_variable_uid options backend ctx scope_uid in *) + let ctx = Desugared.Name_resolution.form_context (prg, mod_uses) modules in Message.emit_debug "Desugaring..."; let prg = Desugared.From_surface.translate_program ctx prg in Message.emit_debug "Disambiguating..."; @@ -122,16 +139,10 @@ module Passes = struct Message.emit_debug "Linting..."; Desugared.Linting.lint_program prg; prg, ctx - (* Note: we forward the name resolution context throughout in order to locate - uids from strings. Maybe a reduced form should be included directly in - [prg] for that purpose *) let scopelang options ~includes : - untyped Scopelang.Ast.program - * Desugared.Name_resolution.context - * Desugared.Dependency.ExceptionsDependencies.t - Desugared.Ast.ScopeDef.Map.t = - let prg, ctx = desugared options ~includes in + untyped Scopelang.Ast.program = + let prg, _ = desugared options ~includes in debug_pass_name "scopelang"; let exceptions_graphs = Scopelang.From_desugared.build_exceptions_graph prg @@ -139,7 +150,7 @@ module Passes = struct let prg = Scopelang.From_desugared.translate_program prg exceptions_graphs in - prg, ctx, exceptions_graphs + prg let dcalc : type ty. @@ -149,10 +160,9 @@ module Passes = struct check_invariants:bool -> typed:ty mark -> ty Dcalc.Ast.program - * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list = fun options ~includes ~optimize ~check_invariants ~typed -> - let prg, ctx, _ = scopelang options ~includes in + let prg = scopelang options ~includes in debug_pass_name "dcalc"; let type_ordering = Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs @@ -199,7 +209,7 @@ module Passes = struct (Message.raise_internal_error "Some Dcalc invariants are invalid") | _ -> Message.raise_error "--check_invariants cannot be used with --no-typing"); - prg, ctx, type_ordering + prg, type_ordering let lcalc (type ty) @@ -211,9 +221,8 @@ module Passes = struct ~avoid_exceptions ~closure_conversion : untyped Lcalc.Ast.program - * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list = - let prg, ctx, type_ordering = + let prg, type_ordering = dcalc options ~includes ~optimize ~check_invariants ~typed in debug_pass_name "lcalc"; @@ -265,7 +274,7 @@ module Passes = struct prg | Custom _ -> assert false) in - prg, ctx, type_ordering + prg, type_ordering let scalc options @@ -275,42 +284,34 @@ module Passes = struct ~avoid_exceptions ~closure_conversion : Scalc.Ast.program - * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list = - let prg, ctx, type_ordering = + let prg, type_ordering = lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed ~avoid_exceptions ~closure_conversion in debug_pass_name "scalc"; - Scalc.From_lcalc.translate_program prg, ctx, type_ordering + Scalc.From_lcalc.translate_program prg, type_ordering end module Commands = struct open Cmdliner - let get_scope_uid (ctxt : Desugared.Name_resolution.context) (scope : string) - = - match Ident.Map.find_opt scope ctxt.typedefs with - | Some (Desugared.Name_resolution.TScope (uid, _)) -> uid - | _ -> + let get_scope_uid (ctx: decl_ctx) (scope : string): ScopeName.t + = + if String.contains scope '.' then + Message.raise_error "Only references to the top-level module are allowed"; + try Ident.Map.find scope ctx.ctx_scope_index with + | Ident.Map.Not_found _ -> Message.raise_error "There is no scope @{\"%s\"@} inside the program." scope (* TODO: this is very weird but I'm trying to maintain the current behaviour for now *) - let get_random_scope_uid (ctxt : Desugared.Name_resolution.context) = - let _, scope = - try - Shared_ast.Ident.Map.filter_map - (fun _ -> function - | Desugared.Name_resolution.TScope (uid, _) -> Some uid - | _ -> None) - ctxt.typedefs - |> Shared_ast.Ident.Map.choose - with Not_found -> - Message.raise_error "There isn't any scope inside the program." - in - scope + let get_random_scope_uid (ctx: decl_ctx): ScopeName.t = + match Ident.Map.choose_opt ctx.ctx_scope_index with + | Some (_, name) -> name + | None -> + Message.raise_error "There isn't any scope inside the program." let get_variable_uid (ctxt : Desugared.Name_resolution.context) @@ -333,7 +334,7 @@ module Commands = struct "Variable @{\"%s\"@} not found inside scope @{\"%a\"@}" variable ScopeName.format scope_uid | Some - (Desugared.Name_resolution.SubScope (subscope_var_name, subscope_name)) + (SubScope (subscope_var_name, subscope_name)) -> ( match second_part with | None -> @@ -353,7 +354,7 @@ module Commands = struct Ident.Map.find_opt second_part (ScopeName.Map.find subscope_name ctxt.scopes).var_idmap with - | Some (Desugared.Name_resolution.ScopeVar v) -> + | Some (ScopeVar v) -> Desugared.Ast.ScopeDef.SubScopeVar (subscope_var_name, v, Pos.no_pos) | _ -> Message.raise_error @@ -362,7 +363,7 @@ module Commands = struct arguments." second_part SubScopeName.format subscope_var_name ScopeName.format scope_uid)) - | Some (Desugared.Name_resolution.ScopeVar v) -> + | Some (ScopeVar v) -> Desugared.Ast.ScopeDef.Var ( v, Option.map @@ -389,7 +390,7 @@ module Commands = struct ~output_file ?ext () let makefile options output = - let prg = Passes.surface options ~includes:[] in + let prg = Passes.surface options in let backend_extensions_list = [".tex"] in let source_file = Cli.input_src_file options.Cli.input_src in let output_file, with_output = get_output options ~ext:".d" output in @@ -415,7 +416,7 @@ module Commands = struct Term.(const makefile $ Cli.Flags.Global.options $ Cli.Flags.output) let html options output print_only_law wrap_weaved_output = - let prg = Passes.surface options ~includes:[] in + let prg = Passes.surface options in Message.emit_debug "Weaving literate program into HTML"; let output_file, with_output = get_output_format options ~ext:".html" output @@ -444,7 +445,7 @@ module Commands = struct $ Cli.Flags.wrap_weaved_output) let latex options output print_only_law wrap_weaved_output = - let prg = Passes.surface options ~includes:[] in + let prg = Passes.surface options in Message.emit_debug "Weaving literate program into LaTeX"; let output_file, with_output = get_output_format options ~ext:".tex" output @@ -473,8 +474,12 @@ module Commands = struct $ Cli.Flags.wrap_weaved_output) let exceptions options includes ex_scope ex_variable = - let _, ctxt, exceptions_graphs = Passes.scopelang options ~includes in - let scope_uid = get_scope_uid ctxt ex_scope in + let prg, ctxt = Passes.desugared options ~includes in + Passes.debug_pass_name "scopelang"; + let exceptions_graphs = + Scopelang.From_desugared.build_exceptions_graph prg + in + let scope_uid = get_scope_uid prg.program_ctx ex_scope in let variable_uid = get_variable_uid ctxt scope_uid ex_variable in Desugared.Print.print_exceptions_graph scope_uid variable_uid (Desugared.Ast.ScopeDef.Map.find variable_uid exceptions_graphs) @@ -496,13 +501,13 @@ module Commands = struct $ Cli.Flags.ex_variable) let scopelang options includes output ex_scope_opt = - let prg, ctx, _ = Passes.scopelang options ~includes in + let prg = Passes.scopelang options ~includes in let _output_file, with_output = get_output_format options output in with_output @@ fun fmt -> match ex_scope_opt with | Some scope -> - let scope_uid = get_scope_uid ctx scope in + let scope_uid = get_scope_uid prg.program_ctx scope in Scopelang.Print.scope ~debug:options.Cli.debug prg.program_ctx fmt (scope_uid, ScopeName.Map.find scope_uid prg.program_scopes); Format.pp_print_newline fmt () @@ -525,7 +530,7 @@ module Commands = struct $ Cli.Flags.ex_scope_opt) let typecheck options includes = - let prg, _, _ = Passes.scopelang options ~includes in + let prg = Passes.scopelang options ~includes in Message.emit_debug "Typechecking..."; let _type_ordering = Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs @@ -547,7 +552,7 @@ module Commands = struct let dcalc typed options includes output optimize ex_scope_opt check_invariants = - let prg, ctx, _ = + let prg, _ = Passes.dcalc options ~includes ~optimize ~check_invariants ~typed in let _output_file, with_output = get_output_format options output in @@ -555,7 +560,7 @@ module Commands = struct @@ fun fmt -> match ex_scope_opt with | Some scope -> - let scope_uid = get_scope_uid ctx scope in + let scope_uid = get_scope_uid prg.decl_ctx scope in Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt ( scope_uid, Option.get @@ -568,7 +573,7 @@ module Commands = struct prg.code_items) ); Format.pp_print_newline fmt () | None -> - let scope_uid = get_random_scope_uid ctx in + let scope_uid = get_random_scope_uid prg.decl_ctx in (* TODO: ??? *) let prg_dcalc_expr = Expr.unbox (Program.to_expr prg scope_uid) in Format.fprintf fmt "%a\n" @@ -602,14 +607,14 @@ module Commands = struct ex_scope_opt check_invariants disable_counterexamples = - let prg, ctx, _ = + let prg, _ = Passes.dcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed in Verification.Globals.setup ~optimize ~disable_counterexamples; let vcs = Verification.Conditions.generate_verification_conditions prg - (Option.map (get_scope_uid ctx) ex_scope_opt) + (Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt) in Verification.Solver.solve_vc prg.decl_ctx vcs @@ -654,12 +659,12 @@ module Commands = struct let interpret_dcalc typed options includes optimize check_invariants ex_scope = - let prg, ctx, _ = + let prg, _ = Passes.dcalc options ~includes ~optimize ~check_invariants ~typed in Interpreter.load_runtime_modules prg; print_interpretation_results options Interpreter.interpret_program_dcalc prg - (get_scope_uid ctx ex_scope) + (get_scope_uid prg.decl_ctx ex_scope) let interpret_cmd = let f no_typing = @@ -691,7 +696,7 @@ module Commands = struct avoid_exceptions closure_conversion ex_scope_opt = - let prg, ctx, _ = + let prg, _ = Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~typed in @@ -700,7 +705,7 @@ module Commands = struct @@ fun fmt -> match ex_scope_opt with | Some scope -> - let scope_uid = get_scope_uid ctx scope in + let scope_uid = get_scope_uid prg.decl_ctx scope in Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt (scope_uid, Program.get_scope_body prg scope_uid); Format.pp_print_newline fmt () @@ -739,13 +744,13 @@ module Commands = struct avoid_exceptions closure_conversion ex_scope = - let prg, ctx, _ = + let prg, _ = Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~typed in Interpreter.load_runtime_modules prg; print_interpretation_results options Interpreter.interpret_program_lcalc prg - (get_scope_uid ctx ex_scope) + (get_scope_uid prg.decl_ctx ex_scope) let interpret_lcalc_cmd = let f no_typing = @@ -777,7 +782,7 @@ module Commands = struct check_invariants avoid_exceptions closure_conversion = - let prg, _, type_ordering = + let prg, type_ordering = Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~typed:Expr.typed in @@ -814,7 +819,7 @@ module Commands = struct avoid_exceptions closure_conversion ex_scope_opt = - let prg, ctx, _ = + let prg, _ = Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in @@ -823,7 +828,7 @@ module Commands = struct @@ fun fmt -> match ex_scope_opt with | Some scope -> - let scope_uid = get_scope_uid ctx scope in + let scope_uid = get_scope_uid prg.decl_ctx scope in Scalc.Print.format_item ~debug:options.Cli.debug prg.decl_ctx fmt (List.find (function @@ -860,7 +865,7 @@ module Commands = struct check_invariants avoid_exceptions closure_conversion = - let prg, _, type_ordering = + let prg, type_ordering = Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in @@ -889,7 +894,7 @@ module Commands = struct $ Cli.Flags.closure_conversion) let r options includes output optimize check_invariants closure_conversion = - let prg, _, type_ordering = + let prg, type_ordering = Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions:false ~closure_conversion in diff --git a/compiler/driver.mli b/compiler/driver.mli index ffa774a7..e5335b0c 100644 --- a/compiler/driver.mli +++ b/compiler/driver.mli @@ -25,7 +25,8 @@ val main : unit -> unit Each pass takes only its cli options, then calls upon its dependent passes (forwarding their options as needed) *) module Passes : sig - val surface : Cli.options -> includes:Cli.raw_file list -> Surface.Ast.program + + val surface : Cli.options -> Surface.Ast.program val desugared : Cli.options -> @@ -36,8 +37,6 @@ module Passes : sig Cli.options -> includes:Cli.raw_file list -> Shared_ast.untyped Scopelang.Ast.program - * Desugared.Name_resolution.context - * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t val dcalc : Cli.options -> @@ -46,7 +45,6 @@ module Passes : sig check_invariants:bool -> typed:'m Shared_ast.mark -> 'm Dcalc.Ast.program - * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list val lcalc : @@ -58,7 +56,6 @@ module Passes : sig avoid_exceptions:bool -> closure_conversion:bool -> Shared_ast.untyped Lcalc.Ast.program - * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list val scalc : @@ -69,7 +66,6 @@ module Passes : sig avoid_exceptions:bool -> closure_conversion:bool -> Scalc.Ast.program - * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list end @@ -90,7 +86,7 @@ module Commands : sig string option * ((Format.formatter -> 'a) -> 'a) val get_scope_uid : - Desugared.Name_resolution.context -> string -> Shared_ast.ScopeName.t + Shared_ast.decl_ctx -> string -> Shared_ast.ScopeName.t val get_variable_uid : Desugared.Name_resolution.context -> diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index f1cc96ae..89549ff4 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -405,26 +405,20 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = let replace_fun_typs t = if type_contains_arrow t then Mark.copy t TAny else t in - let rec convert_ctx ctx = - { - ctx_struct_fields = ctx.ctx_struct_fields; - ctx_modules = ModuleName.Map.map convert_ctx ctx.ctx_modules; - ctx_structs = - StructName.Map.map - (StructField.Map.map replace_fun_typs) - ctx.ctx_structs; - ctx_enums = - EnumName.Map.map - (EnumConstructor.Map.map replace_fun_typs) - ctx.ctx_enums; - ctx_scopes = ctx.ctx_scopes; - ctx_topdefs = ctx.ctx_topdefs; - (* Toplevel definitions may not contain scope calls or take functions as - arguments at the moment, which ensures that their interfaces aren't - changed by the conversion *) - } - in - convert_ctx p.decl_ctx + { + p.decl_ctx with + ctx_structs = + StructName.Map.map + (StructField.Map.map replace_fun_typs) + p.decl_ctx.ctx_structs; + ctx_enums = + EnumName.Map.map + (EnumConstructor.Map.map replace_fun_typs) + p.decl_ctx.ctx_enums; + (* Toplevel definitions may not contain scope calls or take functions as + arguments at the moment, which ensures that their interfaces aren't + changed by the conversion *) + } in Bindlib.box_apply (fun new_code_items -> diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index aa210798..f31a907b 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -439,7 +439,7 @@ let run options = if not options.Cli.trace then Message.raise_error "This plugin requires the --trace flag."; - let prg, _, type_ordering = + let prg, type_ordering = Driver.Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~typed:Expr.typed in diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index b93b07c6..6f2f5255 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -1387,12 +1387,12 @@ let options = $ base_src_url) let run includes optimize ex_scope explain_options global_options = - let prg, ctx, _ = + let prg, _ = Driver.Passes.dcalc global_options ~includes ~optimize ~check_invariants:false ~typed:Expr.typed in Interpreter.load_runtime_modules prg; - let scope = Driver.Commands.get_scope_uid ctx ex_scope in + let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in (* let result_expr, env = interpret_program prg scope in *) let g, base_vars, env = program_to_graph explain_options prg scope in log "Base variables detected: @[%a@]" diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index e7428745..1582e0ea 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -214,7 +214,7 @@ let run closure_conversion ex_scope options = - let prg, ctx, _ = + let prg, _ = Driver.Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion ~typed:Expr.typed in @@ -223,7 +223,7 @@ let run in with_output @@ fun fmt -> - let scope_uid = Driver.Commands.get_scope_uid ctx ex_scope in + let scope_uid = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in Message.emit_debug "Writing JSON schema corresponding to the scope '%a' to the file %s..." ScopeName.format scope_uid diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index 4b799dd8..4dad71f5 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -259,12 +259,12 @@ let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : (* -- Plugin registration -- *) let run includes optimize check_invariants ex_scope options = - let prg, ctx, _ = + let prg, _ = Driver.Passes.dcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed in Interpreter.load_runtime_modules prg; - let scope = Driver.Commands.get_scope_uid ctx ex_scope in + let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in let result_expr, _env = interpret_program prg scope in let fmt = Format.std_formatter in Expr.format fmt result_expr diff --git a/compiler/plugins/python.ml b/compiler/plugins/python.ml index ec0397e7..49bf1d1c 100644 --- a/compiler/plugins/python.ml +++ b/compiler/plugins/python.ml @@ -31,7 +31,7 @@ let run closure_conversion options = let open Driver.Commands in - let prg, _, type_ordering = + let prg, type_ordering = Driver.Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index fe46d57d..6189bd70 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -58,10 +58,10 @@ type 'm scope_decl = { type 'm program = { program_module_name : ModuleName.t option; + program_ctx : decl_ctx; + program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t; - program_modules : nil program ModuleName.Map.t; - program_ctx : decl_ctx; program_lang : Cli.backend_lang; } @@ -77,42 +77,34 @@ let type_rule decl_ctx env = function let pos = Expr.mark_pos m in Call (sc_name, ssc_name, Typed { pos; ty = Mark.add pos TAny }) -let type_program (prg : 'm program) : typed program = +let type_program (type m) (prg : m program) : typed program = (* Caution: this environment building code is very similar to that in desugared/disambiguate.ml. Any edits should probably be reflected. *) - let base_typing_env prg = - let env = Typing.Env.empty prg.program_ctx in - let env = - TopdefName.Map.fold - (fun name ty env -> Typing.Env.add_toplevel_var name ty env) - prg.program_ctx.ctx_topdefs env - in - let env = - ScopeName.Map.fold - (fun scope_name scope_decl env -> - let sg = (Mark.remove scope_decl).scope_sig in - let vars = - ScopeVar.Map.map (fun { svar_out_ty; _ } -> svar_out_ty) sg - in - let in_vars = - ScopeVar.Map.map (fun { svar_in_ty; _ } -> svar_in_ty) sg - in - Typing.Env.add_scope scope_name ~vars ~in_vars env) - prg.program_scopes env - in - env - in - let rec build_typing_env prg = - ModuleName.Map.fold - (fun modname prg -> - Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules (base_typing_env prg) + let env = Typing.Env.empty prg.program_ctx in + let env = + TopdefName.Map.fold + (fun name ty env -> Typing.Env.add_toplevel_var name ty env) + prg.program_ctx.ctx_topdefs env in let env = - ModuleName.Map.fold - (fun modname prg -> - Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules (base_typing_env prg) + ScopeName.Map.fold + (fun scope_name _info env -> + let scope_sig = + match ScopeName.path scope_name with + | [] -> (Mark.remove (ScopeName.Map.find scope_name prg.program_scopes)).scope_sig + | p -> + let m = List.hd (List.rev p) in + let scope = ScopeName.Map.find scope_name (ModuleName.Map.find m prg.program_modules) in + (Mark.remove scope).scope_sig + in + let vars = + ScopeVar.Map.map (fun { svar_out_ty; _ } -> svar_out_ty) scope_sig + in + let in_vars = + ScopeVar.Map.map (fun { svar_in_ty; _ } -> svar_in_ty) scope_sig + in + Typing.Env.add_scope scope_name ~vars ~in_vars env) + prg.program_ctx.ctx_scopes env in let program_topdefs = TopdefName.Map.map diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index ccb33259..654c4fb3 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -51,14 +51,13 @@ type 'm scope_decl = { type 'm program = { program_module_name : ModuleName.t option; - program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; - program_topdefs : ('m expr * typ) TopdefName.Map.t; - program_modules : nil program ModuleName.Map.t; + program_ctx : decl_ctx; + program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t; (* Using [nil] here ensure that program interfaces don't contain any expressions. They won't contain any rules or topdefs, but will still have the scope signatures needed to respect the call convention *) - program_ctx : decl_ctx; + program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; + program_topdefs : ('m expr * typ) TopdefName.Map.t; program_lang : Cli.backend_lang; } - val type_program : 'm program -> typed program diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 16bdddf2..715d539f 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -31,7 +31,6 @@ type ctx = { scope_var_mapping : target_scope_vars ScopeVar.Map.t; reentrant_vars : typ ScopeVar.Map.t; var_mapping : (D.expr, untyped Ast.expr Var.t) Var.Map.t; - modules : ctx ModuleName.Map.t; } let tag_with_log_entry @@ -61,11 +60,6 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = | ELocation (SubScopeVar { scope; alias; var }) -> (* When referring to a subscope variable in an expression, we are referring to the output, hence we take the last state. *) - let ctx = - List.fold_left - (fun ctx m -> ModuleName.Map.find m ctx.modules) - ctx (ScopeName.path scope) - in let var = match ScopeVar.Map.find (Mark.remove var) ctx.scope_var_mapping with | WholeVar new_s_var -> Mark.copy var new_s_var @@ -97,27 +91,8 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = }) m | ELocation (ToplevelVar v) -> Expr.elocation (ToplevelVar v) m - | EDStructAccess { name_opt = None; _ } -> - (* Note: this could only happen if disambiguation was disabled. If we want - to support it, we should still allow this case when the field has only - one possible matching structure *) - Message.raise_spanned_error (Expr.mark_pos m) - "Ambiguous structure field access" - | EDStructAccess { e; field; name_opt = Some name } -> - let e' = translate_expr ctx e in - let field = - let decl_ctx = Program.module_ctx ctx.decl_ctx (StructName.path name) in - try - StructName.Map.find name - (Ident.Map.find field decl_ctx.ctx_struct_fields) - with StructName.Map.Not_found _ | Ident.Map.Not_found _ -> - (* Should not happen after disambiguation *) - Message.raise_spanned_error (Expr.mark_pos m) - "Field @{\"%s\"@} does not belong to structure \ - @{\"%a\"@}" - field StructName.format name - in - Expr.estructaccess ~e:e' ~field ~name m + | EDStructAccess _ -> assert false + (* This shouldn't appear in desugared after disambiguation *) | EScopeCall { scope; args } -> Expr.escopecall ~scope ~args: @@ -168,7 +143,7 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = | op, `Reversed -> Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m) | EOp _ -> assert false (* Only allowed within [EApp] *) - | ( EStruct _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | ELit _ + | ( EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | ELit _ | EApp _ | EDefault _ | EPureDefault _ | EIfThenElse _ | EArray _ | EEmptyError | EErrorOnEmpty _ ) as e -> Expr.map ~f:(translate_expr ctx) (e, m) @@ -300,8 +275,7 @@ let scope_to_exception_graphs (scope : D.scope) : List.fold_left (fun exceptions_graphs scope_def_key -> let new_exceptions_graphs = rule_to_exception_graph scope scope_def_key in - D.ScopeDef.Map.union - (fun _ _ _ -> assert false (* there should not be key conflicts *)) + D.ScopeDef.Map.disjoint_union new_exceptions_graphs exceptions_graphs) D.ScopeDef.Map.empty scope_ordering @@ -310,10 +284,9 @@ let build_exceptions_graph (pgrm : D.program) : ScopeName.Map.fold (fun _ scope exceptions_graph -> let new_exceptions_graphs = scope_to_exception_graphs scope in - D.ScopeDef.Map.union - (fun _ _ _ -> assert false (* key conflicts should not happen*)) + D.ScopeDef.Map.disjoint_union new_exceptions_graphs exceptions_graph) - pgrm.program_scopes D.ScopeDef.Map.empty + pgrm.program_root.module_scopes D.ScopeDef.Map.empty (** Transforms a flat list of rules into a tree, taking into account the priorities declared between rules *) @@ -789,26 +762,31 @@ let translate_program (* First we give mappings to all the locations between Desugared and This involves creating a new Scopelang scope variable for every state of a Desugared variable. *) - let rec make_ctx desugared = - let modules = ModuleName.Map.map make_ctx desugared.D.program_modules in - (* Todo: since we rename all scope vars at this point, it would be better to - have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *) - ScopeName.Map.fold - (fun _scope scope_decl ctx -> - ScopeVar.Map.fold - (fun scope_var (states : D.var_or_states) ctx -> - let var_name, var_pos = ScopeVar.get_info scope_var in - let new_var = - match states with - | D.WholeVar -> WholeVar (ScopeVar.fresh (var_name, var_pos)) - | States states -> - let var_prefix = var_name ^ "_" in - let state_var state = - ScopeVar.fresh - (Mark.map (( ^ ) var_prefix) (StateName.get_info state)) - in - States (List.map (fun state -> state, state_var state) states) - in + let ctx = + let ctx = + { + scope_var_mapping = ScopeVar.Map.empty; + var_mapping = Var.Map.empty; + reentrant_vars = ScopeVar.Map.empty; + decl_ctx = desugared.program_ctx; + } + in + let add_scope_mappings modul ctx = + ScopeName.Map.fold (fun _ scdef ctx -> + ScopeVar.Map.fold + (fun scope_var (states : D.var_or_states) ctx -> + let var_name, var_pos = ScopeVar.get_info scope_var in + let new_var = + match states with + | D.WholeVar -> WholeVar (ScopeVar.fresh (var_name, var_pos)) + | States states -> + let var_prefix = var_name ^ "_" in + let state_var state = + ScopeVar.fresh + (Mark.map (( ^ ) var_prefix) (StateName.get_info state)) + in + States (List.map (fun state -> state, state_var state) states) + in let reentrant = let state = match states with @@ -819,7 +797,7 @@ let translate_program match D.ScopeDef.Map.find_opt (Var (scope_var, state)) - scope_decl.D.scope_defs + scdef.D.scope_defs with | Some { @@ -830,96 +808,53 @@ let translate_program Some scope_def_typ | _ -> None in - { - ctx with - scope_var_mapping = - ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping; + { + ctx with + scope_var_mapping = + ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping; reentrant_vars = Option.fold reentrant ~some:(fun ty -> ScopeVar.Map.add scope_var ty ctx.reentrant_vars) ~none:ctx.reentrant_vars; - }) - scope_decl.D.scope_vars ctx) - desugared.D.program_scopes - { - scope_var_mapping = ScopeVar.Map.empty; - var_mapping = Var.Map.empty; - reentrant_vars = ScopeVar.Map.empty; - decl_ctx = desugared.program_ctx; - modules; - } - in - let ctx = make_ctx desugared in - let rec gather_scope_vars acc modules = - ModuleName.Map.fold - (fun _modname mctx (vmap, reentr) -> - let vmap, reentr = gather_scope_vars (vmap, reentr) mctx.modules in - ( ScopeVar.Map.union - (fun _ _ -> assert false) - vmap mctx.scope_var_mapping, - ScopeVar.Map.union - (fun _ _ -> assert false) - reentr mctx.reentrant_vars )) - modules acc - in - let ctx = - let scope_var_mapping, reentrant_vars = - gather_scope_vars (ctx.scope_var_mapping, ctx.reentrant_vars) ctx.modules + }) + scdef.D.scope_vars ctx) + modul.D.module_scopes ctx in - { ctx with scope_var_mapping; reentrant_vars } + (* Todo: since we rename all scope vars at this point, it would be better to + have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *) + ModuleName.Map.fold (fun _ m ctx -> add_scope_mappings m ctx) + desugared.D.program_modules + (add_scope_mappings (desugared.D.program_root) ctx) in - let rec process_decl_ctx ctx decl_ctx = + let decl_ctx = let ctx_scopes = ScopeName.Map.map (fun out_str -> - let out_struct_fields = - ScopeVar.Map.fold - (fun var fld out_map -> - let var' = - match ScopeVar.Map.find var ctx.scope_var_mapping with - | WholeVar v -> v - | States l -> snd (List.hd (List.rev l)) - in - ScopeVar.Map.add var' fld out_map) - out_str.out_struct_fields ScopeVar.Map.empty - in - { out_str with out_struct_fields }) - decl_ctx.ctx_scopes + let out_struct_fields = + ScopeVar.Map.fold + (fun var fld out_map -> + let var' = + match ScopeVar.Map.find var ctx.scope_var_mapping with + | WholeVar v -> v + | States l -> snd (List.hd (List.rev l)) + in + ScopeVar.Map.add var' fld out_map) + out_str.out_struct_fields ScopeVar.Map.empty + in + { out_str with out_struct_fields }) + desugared.program_ctx.ctx_scopes in - { - decl_ctx with - ctx_modules = - ModuleName.Map.mapi - (fun modname decl_ctx -> - let ctx = ModuleName.Map.find modname ctx.modules in - process_decl_ctx ctx decl_ctx) - decl_ctx.ctx_modules; - ctx_scopes; - } + { desugared.program_ctx with ctx_scopes } in - let rec process_modules program_ctx desugared = - ModuleName.Map.mapi - (fun modname m_desugared -> - let ctx = ModuleName.Map.find modname ctx.modules in - { - Ast.program_module_name = Some modname; - Ast.program_topdefs = TopdefName.Map.empty; - program_scopes = - ScopeName.Map.map - (translate_scope_interface ctx) - m_desugared.D.program_scopes; - program_ctx = ModuleName.Map.find modname program_ctx.ctx_modules; - program_modules = - process_modules - (ModuleName.Map.find modname program_ctx.ctx_modules) - m_desugared; - Ast.program_lang = desugared.program_lang; - }) + let ctx = { ctx with decl_ctx }in + let program_modules = + ModuleName.Map.map (fun m -> + ScopeName.Map.map + (translate_scope_interface ctx) + m.D.module_scopes) desugared.D.program_modules in - let program_ctx = process_decl_ctx ctx desugared.D.program_ctx in - let program_modules = process_modules program_ctx desugared in let program_topdefs = TopdefName.Map.mapi (fun id -> function @@ -927,18 +862,18 @@ let translate_program | None, (_, pos) -> Message.raise_spanned_error pos "No definition found for %a" TopdefName.format id) - desugared.program_topdefs + desugared.program_root.module_topdefs in let program_scopes = ScopeName.Map.map (translate_scope ctx exc_graphs) - desugared.D.program_scopes + desugared.D.program_root.module_scopes in { - Ast.program_module_name = desugared.D.program_module_name; + Ast.program_module_name = Option.map ModuleName.fresh desugared.D.program_module_name; Ast.program_topdefs; Ast.program_scopes; - Ast.program_ctx; + Ast.program_ctx = ctx.decl_ctx; Ast.program_modules; Ast.program_lang = desugared.program_lang; } diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 7e2d906f..9f1dbe1e 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -102,6 +102,10 @@ module SubScopeName = end) () +type scope_var_or_subscope = + | ScopeVar of ScopeVar.t + | SubScope of SubScopeName.t * ScopeName.t + module StateName = Uid.Gen (struct @@ -135,7 +139,6 @@ type desugared = ; overloaded : yes ; resolved : no ; syntacticNames : yes - ; resolvedNames : no ; scopeVarStates : yes ; scopeVarSimpl : no ; explicitScopes : yes @@ -143,6 +146,9 @@ type desugared = ; defaultTerms : yes ; exceptions : no ; custom : no > +(* Technically, desugared before name resolution has [syntacticNames: yes; resolvedNames: no], and after name resolution has the opposite; but the disambiguation being done by the typer, we don't encode this invariant at the type level. + +Indeed, unfortunately, we cannot express the [ -> ] that would be needed for the typing function. *) type scopelang = < monomorphic : yes @@ -150,7 +156,6 @@ type scopelang = ; overloaded : no ; resolved : yes ; syntacticNames : no - ; resolvedNames : yes ; scopeVarStates : no ; scopeVarSimpl : yes ; explicitScopes : yes @@ -165,7 +170,6 @@ type dcalc = ; overloaded : no ; resolved : yes ; syntacticNames : no - ; resolvedNames : yes ; scopeVarStates : no ; scopeVarSimpl : no ; explicitScopes : no @@ -180,7 +184,6 @@ type lcalc = ; overloaded : no ; resolved : yes ; syntacticNames : no - ; resolvedNames : yes ; scopeVarStates : no ; scopeVarSimpl : no ; explicitScopes : no @@ -199,7 +202,6 @@ type dcalc_lcalc_features = ; overloaded : no ; resolved : yes ; syntacticNames : no - ; resolvedNames : yes ; scopeVarStates : no ; scopeVarSimpl : no ; explicitScopes : no @@ -535,8 +537,8 @@ and ('a, 'b, 'm) base_gexpr = e : ('a, 'm) gexpr; field : StructField.t; } - -> ('a, < resolvedNames : yes ; .. >, 'm) base_gexpr - (** Resolved struct/enums, after [desugared] *) + -> ('a, < .. >, 'm) base_gexpr + (** Resolved struct/enums, after name resolution in [desugared] *) (* Lambda-like *) | EExternal : { name : external_ref Mark.pos; @@ -651,8 +653,8 @@ type 'e code_item = | ScopeDef of ScopeName.t * 'e scope_body | Topdef of TopdefName.t * typ * 'e -(* A chained list, but with a binder for each element into the next: [x := let a - = e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *) +(** A chained list, but with a binder for each element into the next: [x := let a + = e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *) type 'e code_item_list = | Nil | Cons of 'e code_item * ('e, 'e code_item_list) binder @@ -666,14 +668,20 @@ type scope_info = { out_struct_fields : StructField.t ScopeVar.Map.t; } +type module_tree = M of module_tree ModuleName.Map.t [@@caml.unboxed] +(** In practice, this is a DAG: beware of repeated names *) + type decl_ctx = { ctx_enums : enum_ctx; ctx_structs : struct_ctx; - ctx_struct_fields : StructField.t StructName.Map.t Ident.Map.t; - (** needed for disambiguation (desugared -> scope) *) ctx_scopes : scope_info ScopeName.Map.t; ctx_topdefs : typ TopdefName.Map.t; - ctx_modules : decl_ctx ModuleName.Map.t; + ctx_struct_fields : StructField.t StructName.Map.t Ident.Map.t; + (** needed for disambiguation (desugared -> scope) *) + ctx_enum_constrs : EnumConstructor.t EnumName.Map.t Ident.Map.t; + ctx_scope_index : ScopeName.t Ident.Map.t; + (** only used to lookup scopes (in the root module) specified from the cli *) + ctx_modules : module_tree; } type 'e program = { diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 146929f4..0fd5654e 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -134,7 +134,7 @@ val estructaccess : field:StructField.t -> e:('a, 'm) boxed_gexpr -> 'm mark -> - ((< resolvedNames : yes ; .. > as 'a), 'm) boxed_gexpr + ('a any, 'm) boxed_gexpr val einj : name:EnumName.t -> diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index c00d4922..140c9b30 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -571,7 +571,6 @@ let rec evaluate_expr : in let ty = try - let ctx = Program.module_ctx ctx path in match Mark.remove name with | External_value name -> TopdefName.Map.find name ctx.ctx_topdefs | External_scope name -> @@ -986,12 +985,13 @@ let load_runtime_modules prg = let obj_file = Dynlink.adapt_filename File.( - (Pos.get_file (ModuleName.pos m) /../ ModuleName.to_string m) ^ ".cmo") + (Pos.get_file (Mark.get (ModuleName.get_info m)) + /../ ModuleName.to_string m) ^ ".cmo") in if not (Sys.file_exists obj_file) then Message.raise_spanned_error ~span_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") - (ModuleName.pos m) + (Mark.get (ModuleName.get_info m)) "Compiled OCaml object %a not found. Make sure it has been suitably \ compiled." File.format obj_file @@ -1003,20 +1003,18 @@ let load_runtime_modules prg = obj_file Format.pp_print_text (Dynlink.error_message dl_err) in - let rec aux loaded decl_ctx = - ModuleName.Map.fold - (fun mname sub_decl_ctx loaded -> - if ModuleName.Set.mem mname loaded then loaded - else - let loaded = ModuleName.Set.add mname loaded in - let loaded = aux loaded sub_decl_ctx in - load mname; - loaded) - decl_ctx.ctx_modules loaded + let modules_list_topo = + let rec aux acc (M mtree) = + ModuleName.Map.fold + (fun mname sub acc -> + if List.exists (ModuleName.equal mname) acc then acc else + mname :: aux acc sub) + mtree acc + in + List.rev (aux [] prg.decl_ctx.ctx_modules) in - if not (ModuleName.Map.is_empty prg.decl_ctx.ctx_modules) then + if modules_list_topo <> [] then Message.emit_debug "Loading shared modules... %a" - (fun ppf -> ModuleName.Map.format_keys ppf) - prg.decl_ctx.ctx_modules; - let (_loaded : ModuleName.Set.t) = aux ModuleName.Set.empty prg.decl_ctx in - () + (Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format) + modules_list_topo; + List.iter load modules_list_topo diff --git a/compiler/shared_ast/print.mli b/compiler/shared_ast/print.mli index becd837b..3fd201e6 100644 --- a/compiler/shared_ast/print.mli +++ b/compiler/shared_ast/print.mli @@ -74,7 +74,7 @@ module type EXPR_PARAM = sig (** pre-processing on expressions: can be used to skip log calls, etc. *) end -module ExprGen (C : EXPR_PARAM) : sig +module ExprGen (_ : EXPR_PARAM) : sig val expr : Format.formatter -> ('a, 't) gexpr -> unit end diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 1dd18a5b..c75c1b4c 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -32,15 +32,14 @@ let empty_ctx = { ctx_enums = EnumName.Map.empty; ctx_structs = StructName.Map.empty; - ctx_struct_fields = Ident.Map.empty; ctx_scopes = ScopeName.Map.empty; ctx_topdefs = TopdefName.Map.empty; - ctx_modules = ModuleName.Map.empty; + ctx_struct_fields = Ident.Map.empty; + ctx_enum_constrs = Ident.Map.empty; + ctx_scope_index = Ident.Map.empty; + ctx_modules = M ModuleName.Map.empty; } -let module_ctx ctx path = - List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.ctx_modules) ctx path - let get_scope_body { code_items; _ } scope = match Scope.fold_left ~init:None diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 58a8f55b..b527ba8f 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -15,17 +15,12 @@ License for the specific language governing permissions and limitations under the License. *) -open Catala_utils open Definitions (** {2 Program declaration context helpers} *) val empty_ctx : decl_ctx -val module_ctx : decl_ctx -> Uid.Path.t -> decl_ctx -(** Follows a path to get the corresponding context for type and value - declarations. *) - (** {2 Transformations} *) val map_exprs : diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index afcfa5e6..75ef112b 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -343,7 +343,6 @@ module Env = struct scopes : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t; scopes_input : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t; toplevel_vars : A.typ A.TopdefName.Map.t; - modules : 'e t A.ModuleName.Map.t; } let empty (decl_ctx : A.decl_ctx) = @@ -363,7 +362,6 @@ module Env = struct scopes = A.ScopeName.Map.empty; scopes_input = A.ScopeName.Map.empty; toplevel_vars = A.TopdefName.Map.empty; - modules = A.ModuleName.Map.empty; } let get t v = Var.Map.find_opt v t.vars @@ -374,9 +372,6 @@ module Env = struct Option.bind (A.ScopeName.Map.find_opt scope t.scopes) (fun vmap -> A.ScopeVar.Map.find_opt var vmap) - let module_env path env = - List.fold_left (fun env m -> A.ModuleName.Map.find m env.modules) env path - let add v tau t = { t with vars = Var.Map.add v tau t.vars } let add_var v typ t = add v (ast_to_typ typ) t @@ -393,19 +388,15 @@ module Env = struct let add_toplevel_var v typ t = { t with toplevel_vars = A.TopdefName.Map.add v typ t.toplevel_vars } - let add_module modname ~module_env t = - { t with modules = A.ModuleName.Map.add modname module_env t.modules } - let open_scope scope_name t = let scope_vars = - A.ScopeVar.Map.union - (fun _ _ -> assert false) + A.ScopeVar.Map.disjoint_union t.scope_vars (A.ScopeName.Map.find scope_name t.scopes) in { t with scope_vars } - let rec dump ppf env = + let dump ppf env = let pp_sep = Format.pp_print_space in Format.pp_open_vbox ppf 0; (* Format.fprintf ppf "structs: @[%a@]@," @@ -420,9 +411,6 @@ module Env = struct Format.fprintf ppf "topdefs: @[%a@]@," (A.TopdefName.Map.format_keys ~pp_sep) env.toplevel_vars; - Format.fprintf ppf "@[modules:@ %a@]" - (A.ModuleName.Map.format dump) - env.modules; Format.pp_close_box ppf () end @@ -480,10 +468,8 @@ and typecheck_expr_top_down : | DesugaredScopeVar { name; _ } | ScopelangScopeVar { name } -> Env.get_scope_var env (Mark.remove name) | SubScopeVar { scope; var; _ } -> - let env = Env.module_env (A.ScopeName.path scope) env in Env.get_subscope_out_var env scope (Mark.remove var) | ToplevelVar { name } -> - let env = Env.module_env (A.TopdefName.path (Mark.remove name)) env in Env.get_toplevel_var env (Mark.remove name) in let ty = @@ -558,42 +544,39 @@ and typecheck_expr_top_down : "This is not a structure, cannot access field %s (%a)" field (format_typ ctx) (ty e_struct') in - let fld_ty = - let str = - try A.StructName.Map.find name env.structs - with A.StructName.Map.Not_found _ -> - Message.raise_spanned_error pos_e "No structure %a found" - A.StructName.format name - in - let field = - let ctx = Program.module_ctx ctx (A.StructName.path name) in - let candidate_structs = - try A.Ident.Map.find field ctx.ctx_struct_fields - with A.Ident.Map.Not_found _ -> - Message.raise_spanned_error - (Expr.mark_pos context_mark) - "Field @{\"%s\"@} does not belong to structure \ - @{\"%a\"@} (no structure defines it)" - field A.StructName.format name - in - try A.StructName.Map.find name candidate_structs - with A.StructName.Map.Not_found _ -> + let str = + try A.StructName.Map.find name env.structs + with A.StructName.Map.Not_found _ -> + Message.raise_spanned_error pos_e "No structure %a found" + A.StructName.format name + in + let field = + let candidate_structs = + try A.Ident.Map.find field ctx.ctx_struct_fields + with A.Ident.Map.Not_found _ -> Message.raise_spanned_error (Expr.mark_pos context_mark) - "@[Field @{\"%s\"@}@ does not belong to@ structure \ - @{\"%a\"@},@ but to %a@]" + "Field @{\"%s\"@} does not belong to structure \ + @{\"%a\"@} (no structure defines it)" field A.StructName.format name - (Format.pp_print_list - ~pp_sep:(fun ppf () -> Format.fprintf ppf "@ or@ ") - (fun fmt s_name -> - Format.fprintf fmt "@{\"%a\"@}" A.StructName.format - s_name)) - (A.StructName.Map.keys candidate_structs) in - A.StructField.Map.find field str + try A.StructName.Map.find name candidate_structs + with A.StructName.Map.Not_found _ -> + Message.raise_spanned_error + (Expr.mark_pos context_mark) + "@[Field @{\"%s\"@}@ does not belong to@ structure \ + @{\"%a\"@},@ but to %a@]" + field A.StructName.format name + (Format.pp_print_list + ~pp_sep:(fun ppf () -> Format.fprintf ppf "@ or@ ") + (fun fmt s_name -> + Format.fprintf fmt "@{\"%a\"@}" A.StructName.format + s_name)) + (A.StructName.Map.keys candidate_structs) in + let fld_ty = A.StructField.Map.find field str in let mark = mark_with_tau_and_unify fld_ty in - Expr.edstructaccess ~e:e_struct' ~name_opt:(Some name) ~field mark + Expr.estructaccess ~name ~e:e_struct' ~field mark | A.EStructAccess { e = e_struct; name; field } -> let fld_ty = let str = @@ -692,16 +675,11 @@ and typecheck_expr_top_down : in Expr.ematch ~e:e1' ~name ~cases mark | A.EScopeCall { scope; args } -> - let path = A.ScopeName.path scope in let scope_out_struct = - let ctx = Program.module_ctx ctx path in (A.ScopeName.Map.find scope ctx.ctx_scopes).out_struct_name in let mark = mark_with_tau_and_unify (unionfind (TStruct scope_out_struct)) in - let vars = - let env = Env.module_env path env in - A.ScopeName.Map.find scope env.scopes_input - in + let vars = A.ScopeName.Map.find scope env.scopes_input in let args' = A.ScopeVar.Map.mapi (fun name -> @@ -730,12 +708,6 @@ and typecheck_expr_top_down : in Expr.evar (Var.translate v) (mark_with_tau_and_unify tau') | A.EExternal { name } -> - let path = - match Mark.remove name with - | External_value td -> A.TopdefName.path td - | External_scope s -> A.ScopeName.path s - in - let ctx = Program.module_ctx ctx path in let ty = let not_found pr x = Message.raise_spanned_error pos_e diff --git a/compiler/shared_ast/typing.mli b/compiler/shared_ast/typing.mli index ba177106..3004740a 100644 --- a/compiler/shared_ast/typing.mli +++ b/compiler/shared_ast/typing.mli @@ -17,7 +17,6 @@ (** Typing for the default calculus. Because of the error terms, we perform type inference using the classical W algorithm with union-find unification. *) -open Catala_utils open Definitions module Env : sig @@ -35,8 +34,6 @@ module Env : sig 'e t -> 'e t - val add_module : ModuleName.t -> module_env:'e t -> 'e t -> 'e t - val module_env : Uid.Path.t -> 'e t -> 'e t val open_scope : ScopeName.t -> 'e t -> 'e t val dump : Format.formatter -> 'e t -> unit @@ -62,7 +59,10 @@ val expr : still done, but with unification with the existing annotations at every step. This can be used for double-checking after AST transformations and filling the gaps ([TAny]) if any. Use [Expr.untype] first if this is not - what you want. *) + what you want. + + Note that typing also transparently performs disambiguation of constructors: [EDStructAccess] nodes are translated into [EStructAccess] with the suitable structure and field idents (this only concerns [desugared] expressions). +*) val check_expr : leave_unresolved:bool -> diff --git a/compiler/surface/ast.ml b/compiler/surface/ast.ml index 451b1855..168e073f 100644 --- a/compiler/surface/ast.ml +++ b/compiler/surface/ast.ml @@ -312,15 +312,24 @@ and law_structure = | LawText of (string[@opaque]) | CodeBlock of code_block * source_repr * bool (* Metadata if true *) -and interface = uident Mark.pos * code_block -(** Invariant: an interface shall only contain [*Decl] elements, or [Topdef] - elements with [topdef_expr = None] *) +and interface = { + intf_modname: uident Mark.pos; + intf_code: code_block; + (** Invariant: an interface shall only contain [*Decl] elements, or [Topdef] + elements with [topdef_expr = None] *) + intf_submodules: module_use list; +} + +and module_use = { + mod_use_name: uident Mark.pos; + mod_use_alias: uident Mark.pos; +} and program = { program_module_name : uident Mark.pos option; program_items : law_structure list; program_source_files : (string[@opaque]) list; - program_modules : interface list; (** Modules being used by the program *) + program_used_modules : module_use list; program_lang : Cli.backend_lang; [@opaque] } diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index f5cee8ba..b84dfae6 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -248,10 +248,8 @@ let rec parse_source (lexbuf : Sedlexing.lexbuf) : Ast.program = let commands = localised_parser language lexbuf in let program = expand_includes source_file_name commands in { - program_module_name = program.Ast.program_module_name; - program_items = program.Ast.program_items; + program with program_source_files = source_file_name :: program.Ast.program_source_files; - program_modules = program.program_modules; program_lang = language; } @@ -278,10 +276,12 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : Ast.program_module_name = join_module_names (Some id); Ast.program_items = command :: acc.Ast.program_items; } - | Ast.ModuleUse (id, _alias) -> + | Ast.ModuleUse (mod_use_name, alias) -> + let mod_use_alias = Option.value ~default:mod_use_name alias in { acc with - Ast.program_modules = (id, []) :: acc.Ast.program_modules; + Ast.program_used_modules = { mod_use_name; mod_use_alias } + :: acc.Ast.program_used_modules; Ast.program_items = command :: acc.Ast.program_items; } | Ast.LawInclude (Ast.CatalaFile inc_file) -> @@ -301,8 +301,8 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : ] "A file that declares a module cannot be used through the raw \ '@{> Include@}' directive. You should use it as a \ - module with '@{> Use %a@}' instead." - Uid.Module.format (Uid.Module.of_string id) + module with '@{> Use @{%s@}@}' instead." + (Mark.remove id) in { Ast.program_module_name = acc.program_module_name; @@ -311,9 +311,9 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : acc.Ast.program_source_files; Ast.program_items = List.rev_append includ_program.program_items acc.Ast.program_items; - Ast.program_modules = - List.rev_append includ_program.program_modules - acc.Ast.program_modules; + Ast.program_used_modules = + List.rev_append includ_program.program_used_modules + acc.Ast.program_used_modules; Ast.program_lang = language; } | Ast.LawHeading (heading, commands') -> @@ -321,7 +321,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : Ast.program_module_name; Ast.program_items = commands'; Ast.program_source_files = new_sources; - Ast.program_modules = new_modules; + Ast.program_used_modules = new_used_modules; Ast.program_lang = _; } = expand_includes source_file commands' @@ -332,8 +332,8 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : List.rev_append new_sources acc.Ast.program_source_files; Ast.program_items = Ast.LawHeading (heading, commands') :: acc.Ast.program_items; - Ast.program_modules = - List.rev_append new_modules acc.Ast.program_modules; + Ast.program_used_modules = + List.rev_append new_used_modules acc.Ast.program_used_modules; Ast.program_lang = language; } | i -> { acc with Ast.program_items = i :: acc.Ast.program_items }) @@ -341,7 +341,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : Ast.program_module_name = None; Ast.program_source_files = []; Ast.program_items = []; - Ast.program_modules = []; + Ast.program_used_modules = []; Ast.program_lang = language; } commands @@ -351,7 +351,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) : Ast.program_module_name = rprg.Ast.program_module_name; Ast.program_source_files = List.rev rprg.Ast.program_source_files; Ast.program_items = List.rev rprg.Ast.program_items; - Ast.program_modules = List.rev rprg.Ast.program_modules; + Ast.program_used_modules = List.rev rprg.Ast.program_used_modules; } (** {2 Handling interfaces} *) @@ -360,7 +360,9 @@ let get_interface program = let rec filter (req, acc) = function | Ast.LawInclude _ | Ast.LawText _ | Ast.ModuleDef _ -> req, acc | Ast.LawHeading (_, str) -> List.fold_left filter (req, acc) str - | Ast.ModuleUse (m, _) -> m :: req, acc + | Ast.ModuleUse (mod_use_name, alias) -> + { Ast.mod_use_name; mod_use_alias = Option.value ~default:mod_use_name alias } + :: req, acc | Ast.CodeBlock (code, _, true) -> ( req, List.fold_left @@ -394,9 +396,17 @@ let with_sedlex_source source_file f = let load_interface source_file = let program = with_sedlex_source source_file parse_source in let modname = - match program.Ast.program_module_name with - | Some mname -> mname - | None -> + match program.Ast.program_module_name, source_file with + | Some (mname, pos), Cli.FileName file -> + if File.(equal mname Filename.(remove_extension (basename file))) + then mname, pos + else + Message.raise_spanned_error pos + "Module declared as @{%s@}, which does not match the file name %a" + mname + File.format file + | Some mname, _ -> mname + | None, _ -> Message.raise_error "%a doesn't define a module name. It should contain a '@{> \ Module %s@}' directive." @@ -408,7 +418,9 @@ let load_interface source_file = | _ -> "Module_name") in let used_modules, intf = get_interface program in - (modname, intf), used_modules + { Ast.intf_modname = modname; + Ast.intf_code = intf; + Ast.intf_submodules = used_modules; } let parse_top_level_file (source_file : Cli.input_src) : Ast.program = let program = with_sedlex_source source_file parse_source in diff --git a/compiler/surface/parser_driver.mli b/compiler/surface/parser_driver.mli index 9b19ae9e..091bde57 100644 --- a/compiler/surface/parser_driver.mli +++ b/compiler/surface/parser_driver.mli @@ -24,9 +24,9 @@ val lines : (** Raw file parser that doesn't interpret any includes and returns the flat law structure as is *) -val load_interface : Cli.input_src -> Ast.interface * string Mark.pos list +val load_interface : Cli.input_src -> Ast.interface (** Reads only declarations in metadata in the supplied input file, and only - keeps type information ; returns the modules used as well *) + keeps type information. The list of submodules is initialised with names only and empty contents. *) val parse_top_level_file : Cli.input_src -> Ast.program (** Parses a catala file (handling file includes) and returns a program. diff --git a/dune b/dune index 9c96ed4e..82204c15 100644 --- a/dune +++ b/dune @@ -10,7 +10,7 @@ ; don't stop building because of warnings (dev (flags - (:standard -warn-error -a))) + (:standard -warn-error -a -w -67))) ; for CI runs: must fail on warnings (check (flags diff --git a/tests/test_modules/good/mod_use2.catala_en b/tests/test_modules/good/mod_use2.catala_en index 3920042f..72a332ae 100644 --- a/tests/test_modules/good/mod_use2.catala_en +++ b/tests/test_modules/good/mod_use2.catala_en @@ -4,7 +4,7 @@ declaration scope T: t1 scope Mod_middle.S # input i content Enum1 - output o1 content Mod_def.S + output o1 content Mod_middle.Mod_def.S output o2 content money output o3 content money