Rework resolution of module elements

This changes the `decl_ctx` to be toplevel only, with flattened references to
uids for most elements. The module hierarchy, which is still useful in a few
places, is kept separately.

Module names are also changed to UIDs early on, and support for module aliases
has been added (needs testing).

This resolves some issues with lookup, and should be much more robust, as well
as more convenient for most lookups.

The `decl_ctx` was also extended for string ident lookups, which avoids having
to keep the desugared resolution structure available throughout the compilation
chain.
This commit is contained in:
Louis Gesbert 2023-11-20 16:01:06 +01:00
parent 86b7f80e90
commit 3649f92975
37 changed files with 860 additions and 987 deletions

View File

@ -545,6 +545,14 @@ let[@ocamlformat "disable"] static_base_rules =
"fi"; "fi";
] ]
~description:["<test>"; !output]; ~description:["<test>"; !output];
(* Note: this last rule looks horrible, but the processing is pretty simple:
in the rules above, we output the returning code of diffing individual
tests to a [<testfile>@test] file, then the rules for directories just
concat these files. What this last rule does is then just count the number
of `0` and the total number of characters in the file, and print a readable
message. Instead of this disgusting shell code embedded in the ninja file,
this could be a specialised subcommand of clerk, e.g. `clerk
test-diagnostic <results-file@test>` *)
] ]
let gen_build_statements let gen_build_statements
@ -641,7 +649,7 @@ let gen_build_statements
(if Filename.is_relative d then !Var.builddir / d else d); (if Filename.is_relative d then !Var.builddir / d else d);
]) ])
include_dirs include_dirs
@ (List.map (fun m -> m ^".cmx") modules) ); @ List.map (fun m -> m ^ ".cmx") modules );
] ]
in in
let expose_module = let expose_module =
@ -694,6 +702,7 @@ let gen_build_statements
diff; it should actually be an output for the cases when we diff; it should actually be an output for the cases when we
reset but that shouldn't cause trouble. *) reset but that shouldn't cause trouble. *)
Nj.build "post-test" ~inputs:[reference; test_out] Nj.build "post-test" ~inputs:[reference; test_out]
~implicit_in:["always"]
~outputs:[reference ^ "@post"] ~outputs:[reference ^ "@post"]
:: acc) :: acc)
[] item.legacy_tests [] item.legacy_tests
@ -720,7 +729,8 @@ let gen_build_statements
~outputs:[inc (srcv ^ "@test")] ~outputs:[inc (srcv ^ "@test")]
~inputs:[srcv; inc (srcv ^ "@out")] ~inputs:[srcv; inc (srcv ^ "@out")]
~implicit_in: ~implicit_in:
(List.map ("always" ::
List.map
(fun test -> legacy_test_reference test ^ "@post") (fun test -> legacy_test_reference test ^ "@post")
item.legacy_tests); item.legacy_tests);
results; results;
@ -801,7 +811,8 @@ let gen_ninja_file catala_exe catala_flags build_dir include_dirs dir =
@+ List.to_seq (base_bindings catala_exe catala_flags build_dir include_dirs) @+ List.to_seq (base_bindings catala_exe catala_flags build_dir include_dirs)
@+ Seq.return (Nj.Comment "\n- Base rules - #\n") @+ Seq.return (Nj.Comment "\n- Base rules - #\n")
@+ List.to_seq static_base_rules @+ List.to_seq static_base_rules
@+ Seq.return (Nj.Comment "- Project-specific build statements - #") @+ Seq.return (Nj.build "phony" ~outputs:["always"])
@+ Seq.return (Nj.Comment "\n- Project-specific build statements - #")
@+ build_statements include_dirs dir @+ build_statements include_dirs dir
@+ Seq.return (Nj.build "phony" ~outputs:["test"] ~inputs:[".@test"]) @+ Seq.return (Nj.build "phony" ~outputs:["test"] ~inputs:[".@test"])

View File

@ -36,6 +36,7 @@ module type S = sig
val keys : 'a t -> key list val keys : 'a t -> key list
val values : 'a t -> 'a list val values : 'a t -> 'a list
val of_list : (key * 'a) list -> 'a t val of_list : (key * 'a) list -> 'a t
val disjoint_union : 'a t -> 'a t -> 'a t
val format_keys : val format_keys :
?pp_sep:(Format.formatter -> unit -> unit) -> ?pp_sep:(Format.formatter -> unit -> unit) ->
@ -87,6 +88,12 @@ module Make (Ord : OrderedType) : S with type key = Ord.t = struct
let keys t = fold (fun k _ acc -> k :: acc) t [] |> List.rev let keys t = fold (fun k _ acc -> k :: acc) t [] |> List.rev
let values t = fold (fun _ v acc -> v :: acc) t [] |> List.rev let values t = fold (fun _ v acc -> v :: acc) t [] |> List.rev
let of_list l = List.fold_left (fun m (k, v) -> add k v m) empty l let of_list l = List.fold_left (fun m (k, v) -> add k v m) empty l
let disjoint_union t1 t2 =
union (fun k _ _ ->
Format.kasprintf failwith
"Maps are not disjoint: conflict on key %a"
Ord.format k)
t1 t2
let format_keys ?pp_sep ppf t = let format_keys ?pp_sep ppf t =
Format.pp_print_list ?pp_sep Ord.format ppf (keys t) Format.pp_print_list ?pp_sep Ord.format ppf (keys t)

View File

@ -32,6 +32,7 @@ module type Id = sig
val compare : t -> t -> int val compare : t -> t -> int
val equal : t -> t -> bool val equal : t -> t -> bool
val format : Format.formatter -> t -> unit val format : Format.formatter -> t -> unit
val to_string : t -> string
val hash : t -> int val hash : t -> int
module Set : Set.S with type elt = t module Set : Set.S with type elt = t
@ -68,6 +69,8 @@ module Make (X : Info) (S : Style) () : Id with type info = X.info = struct
let get_info (uid : t) : X.info = uid.info let get_info (uid : t) : X.info = uid.info
let hash (x : t) : int = x.id let hash (x : t) : int = x.id
let to_string t = X.to_string t.info
module Set = Set.Make (Ordering) module Set = Set.Make (Ordering)
module Map = Map.Make (Ordering) module Map = Map.Make (Ordering)
end end
@ -87,27 +90,12 @@ module Gen (S : Style) () = Make (MarkedString) (S) ()
(* - Modules, paths and qualified idents - *) (* - Modules, paths and qualified idents - *)
module Module = struct module Module =
module Ordering = struct Gen
type t = string Mark.pos (struct
let style = Ocolor_types.(Fg (C4 blue))
let equal = Mark.equal String.equal end)
let compare = Mark.compare String.compare ()
let format ppf m = Format.fprintf ppf "@{<blue>%s@}" (Mark.remove m)
end
include Ordering
let to_string m = Mark.remove m
let of_string m = m
let pos m = Mark.get m
module Set = Set.Make (Ordering)
module Map = Map.Make (Ordering)
end
(* TODO: should probably be turned into an uid once we implement module import
directives; that will incur an additional resolution work on all paths though
([module Module = Gen ()]) *)
module Path = struct module Path = struct
type t = Module.t list type t = Module.t list

View File

@ -47,6 +47,7 @@ module type Id = sig
val compare : t -> t -> int val compare : t -> t -> int
val equal : t -> t -> bool val equal : t -> t -> bool
val format : Format.formatter -> t -> unit val format : Format.formatter -> t -> unit
val to_string : t -> string
val hash : t -> int val hash : t -> int
module Set : Set.S with type elt = t module Set : Set.S with type elt = t
@ -62,27 +63,14 @@ end
(** This is the generative functor that ensures that two modules resulting from (** This is the generative functor that ensures that two modules resulting from
two different calls to [Make] will be viewed as different types [t] by the two different calls to [Make] will be viewed as different types [t] by the
OCaml typechecker. Prevents mixing up different sorts of identifiers. *) OCaml typechecker. Prevents mixing up different sorts of identifiers. *)
module Make (X : Info) (S : Style) () : Id with type info = X.info module Make (X : Info) (_ : Style) () : Id with type info = X.info
(** Shortcut for creating a kind of uids over marked strings *) (** Shortcut for creating a kind of uids over marked strings *)
module Gen (S : Style) () : Id with type info = MarkedString.info module Gen (_ : Style) () : Id with type info = MarkedString.info
(** {2 Handling of Uids with additional path information} *) (** {2 Handling of Uids with additional path information} *)
module Module : sig module Module : Id with type info = MarkedString.info
type t = private string Mark.pos
(* TODO: this will become an uid at some point *)
val to_string : t -> string
val format : Format.formatter -> t -> unit
val pos : t -> Pos.t
val equal : t -> t -> bool
val compare : t -> t -> int
val of_string : string * Pos.t -> t
module Set : Set.S with type elt = t
module Map : Map.S with type key = t
end
module Path : sig module Path : sig
type t = Module.t list type t = Module.t list
@ -94,7 +82,7 @@ module Path : sig
end end
(** Same as [Gen] but also registers path information *) (** Same as [Gen] but also registers path information *)
module Gen_qualified (S : Style) () : sig module Gen_qualified (_ : Style) () : sig
include Id with type info = Path.t * MarkedString.info include Id with type info = Path.t * MarkedString.info
val fresh : Path.t -> MarkedString.info -> t val fresh : Path.t -> MarkedString.info -> t

View File

@ -23,10 +23,10 @@ let () =
~input_src:(Contents (contents, "-inline-")) ~input_src:(Contents (contents, "-inline-"))
~language:(Some language) ~debug:false ~color:Never ~trace () ~language:(Some language) ~debug:false ~color:Never ~trace ()
in in
let prg, ctx, _type_order = let prg, _type_order =
Passes.dcalc options ~includes:[] ~optimize:false Passes.dcalc options ~includes:[] ~optimize:false
~check_invariants:false ~typed:Shared_ast.Expr.typed ~check_invariants:false ~typed:Shared_ast.Expr.typed
in in
Shared_ast.Interpreter.interpret_program_dcalc prg Shared_ast.Interpreter.interpret_program_dcalc prg
(Commands.get_scope_uid ctx scope) (Commands.get_scope_uid prg.decl_ctx scope)
end) end)

View File

@ -47,15 +47,10 @@ type 'm scope_sig_ctx = {
(** Mapping between the input scope variables and the input struct fields. *) (** Mapping between the input scope variables and the input struct fields. *)
} }
type 'm scope_sigs_ctx = {
scope_sigs : 'm scope_sig_ctx ScopeName.Map.t;
scope_sigs_modules : 'm scope_sigs_ctx ModuleName.Map.t;
}
type 'm ctx = { type 'm ctx = {
decl_ctx : decl_ctx; decl_ctx : decl_ctx;
scope_name : ScopeName.t option; scope_name : ScopeName.t option;
scopes_parameters : 'm scope_sigs_ctx; scopes_parameters : 'm scope_sig_ctx ScopeName.Map.t;
toplevel_vars : ('m Ast.expr Var.t * naked_typ) TopdefName.Map.t; toplevel_vars : ('m Ast.expr Var.t * naked_typ) TopdefName.Map.t;
scope_vars : scope_vars :
('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVar.Map.t; ('m Ast.expr Var.t * naked_typ * Desugared.Ast.io) ScopeVar.Map.t;
@ -77,14 +72,6 @@ let pos_mark_mk (type a m) (e : (a, m) gexpr) :
let pos_mark_as e = pos_mark (Mark.get e) in let pos_mark_as e = pos_mark (Mark.get e) in
pos_mark, pos_mark_as pos_mark, pos_mark_as
let module_scope_sig scope_sig_ctx scope =
let ssctx =
List.fold_left
(fun ssctx m -> ModuleName.Map.find m ssctx.scope_sigs_modules)
scope_sig_ctx (ScopeName.path scope)
in
ScopeName.Map.find scope ssctx.scope_sigs
let merge_defaults let merge_defaults
~(is_func : bool) ~(is_func : bool)
(caller : (dcalc, 'm) boxed_gexpr) (caller : (dcalc, 'm) boxed_gexpr)
@ -261,7 +248,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
Expr.ematch ~e:e1 ~name ~cases:d_cases m Expr.ematch ~e:e1 ~name ~cases:d_cases m
| EScopeCall { scope; args } -> | EScopeCall { scope; args } ->
let pos = Expr.mark_pos m in let pos = Expr.mark_pos m in
let sc_sig = module_scope_sig ctx.scopes_parameters scope in let sc_sig = ScopeName.Map.find scope ctx.scopes_parameters in
let in_var_map = let in_var_map =
ScopeVar.Map.merge ScopeVar.Map.merge
(fun var_name (str_field : scope_input_var_ctx option) expr -> (fun var_name (str_field : scope_input_var_ctx option) expr ->
@ -522,10 +509,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
|> SubScopeName.Map.find (Mark.remove alias) |> SubScopeName.Map.find (Mark.remove alias)
|> retrieve_in_and_out_typ_or_any var |> retrieve_in_and_out_typ_or_any var
| ELocation (ToplevelVar { name }) -> ( | ELocation (ToplevelVar { name }) -> (
let decl_ctx = let typ = TopdefName.Map.find (Mark.remove name) ctx.decl_ctx.ctx_topdefs in
Program.module_ctx ctx.decl_ctx (TopdefName.path (Mark.remove name))
in
let typ = TopdefName.Map.find (Mark.remove name) decl_ctx.ctx_topdefs in
match Mark.remove typ with match Mark.remove typ with
| TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout
| _ -> | _ ->
@ -735,10 +719,9 @@ let translate_rule
could be made more specific to avoid this case, but the added complexity could be made more specific to avoid this case, but the added complexity
didn't seem worth it *) didn't seem worth it *)
| Call (subname, subindex, m) -> | Call (subname, subindex, m) ->
let subscope_sig = module_scope_sig ctx.scopes_parameters subname in let subscope_sig = ScopeName.Map.find subname ctx.scopes_parameters in
let scope_sig_decl = let scope_sig_decl =
ScopeName.Map.find subname ScopeName.Map.find subname ctx.decl_ctx.ctx_scopes
(Program.module_ctx ctx.decl_ctx (ScopeName.path subname)).ctx_scopes
in in
let all_subscope_vars = subscope_sig.scope_sig_local_vars in let all_subscope_vars = subscope_sig.scope_sig_local_vars in
let all_subscope_input_vars = let all_subscope_input_vars =
@ -968,7 +951,7 @@ let translate_scope_decl
(sigma : 'm Scopelang.Ast.scope_decl) = (sigma : 'm Scopelang.Ast.scope_decl) =
let sigma_info = ScopeName.get_info sigma.scope_decl_name in let sigma_info = ScopeName.get_info sigma.scope_decl_name in
let scope_sig = let scope_sig =
ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters.scope_sigs ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters
in in
let scope_variables = scope_sig.scope_sig_local_vars in let scope_variables = scope_sig.scope_sig_local_vars in
let ctx = { ctx with scope_name = Some scope_name } in let ctx = { ctx with scope_name = Some scope_name } in
@ -1088,8 +1071,8 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
Scopelang.Dependency.get_defs_ordering defs_dependencies Scopelang.Dependency.get_defs_ordering defs_dependencies
in in
let decl_ctx = prgm.program_ctx in let decl_ctx = prgm.program_ctx in
let sctx : 'm scope_sigs_ctx = let scopes_parameters : 'm scope_sig_ctx ScopeName.Map.t =
let process_scope_sig scope_name scope = let process_scope_sig decl_ctx scope_name scope =
let scope_path = ScopeName.path scope_name in let scope_path = ScopeName.path scope_name in
let scope_ref = let scope_ref =
if scope_path = [] then if scope_path = [] then
@ -1100,13 +1083,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
(Mark.copy (ScopeName.get_info scope_name) scope_name) (Mark.copy (ScopeName.get_info scope_name) scope_name)
in in
let scope_info = let scope_info =
try ScopeName.Map.find scope_name decl_ctx.ctx_scopes
ScopeName.Map.find scope_name
(Program.module_ctx decl_ctx scope_path).ctx_scopes
with ScopeName.Map.Not_found _ ->
Message.raise_spanned_error
(Mark.get (ScopeName.get_info scope_name))
"Could not find scope %a" ScopeName.format scope_name
in in
let scope_sig_in_fields = let scope_sig_in_fields =
(* Output fields have already been generated and added to the program (* Output fields have already been generated and added to the program
@ -1154,69 +1131,45 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
scope_sig_in_fields; scope_sig_in_fields;
} }
in in
let rec process_modules prg = let process_scopes scopes =
{ ScopeName.Map.mapi
scope_sigs = (fun scope_name (scope_decl, _) ->
ScopeName.Map.mapi process_scope_sig decl_ctx scope_name scope_decl)
(fun scope_name (scope_decl, _) -> scopes
process_scope_sig scope_name scope_decl)
prg.Scopelang.Ast.program_scopes;
scope_sigs_modules =
ModuleName.Map.map process_modules prg.Scopelang.Ast.program_modules;
}
in in
{ ModuleName.Map.fold (fun _ s ->
scope_sigs = ScopeName.Map.disjoint_union
ScopeName.Map.mapi (process_scopes s))
(fun scope_name (scope_decl, _) -> prgm.Scopelang.Ast.program_modules
process_scope_sig scope_name scope_decl) (process_scopes prgm.Scopelang.Ast.program_scopes)
prgm.Scopelang.Ast.program_scopes;
scope_sigs_modules =
ModuleName.Map.map process_modules prgm.Scopelang.Ast.program_modules;
}
in in
let add_scope_in_structs scope_sigs structs = let ctx_structs =
ScopeName.Map.fold ScopeName.Map.fold
(fun _ scope_sig_ctx acc -> (fun _ scope_sig_ctx acc ->
let fields = let fields =
ScopeVar.Map.fold ScopeVar.Map.fold
(fun _ sivc acc -> (fun _ sivc acc ->
let pos = Mark.get (StructField.get_info sivc.scope_input_name) in let pos = Mark.get (StructField.get_info sivc.scope_input_name) in
StructField.Map.add sivc.scope_input_name StructField.Map.add sivc.scope_input_name
(sivc.scope_input_typ, pos) (sivc.scope_input_typ, pos)
acc) acc)
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
in in
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
scope_sigs.scope_sigs structs scopes_parameters decl_ctx.ctx_structs
in in
let rec gather_module_in_structs acc sctx = let decl_ctx = { decl_ctx with ctx_structs } in
(* Expose all added in_structs from submodules at toplevel *) let toplevel_vars =
ModuleName.Map.fold TopdefName.Map.mapi
(fun _ scope_sigs acc -> (fun name (_, ty) ->
add_scope_in_structs scope_sigs Var.make (Mark.remove (TopdefName.get_info name)), Mark.remove ty)
(gather_module_in_structs acc scope_sigs.scope_sigs_modules)) prgm.Scopelang.Ast.program_topdefs
sctx acc
in in
let decl_ctx = let ctx =
{
decl_ctx with
ctx_structs =
add_scope_in_structs sctx
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
}
in
let top_ctx =
let toplevel_vars =
TopdefName.Map.mapi
(fun name (_, ty) ->
Var.make (Mark.remove (TopdefName.get_info name)), Mark.remove ty)
prgm.Scopelang.Ast.program_topdefs
in
{ {
decl_ctx; decl_ctx;
scope_name = None; scope_name = None;
scopes_parameters = sctx; scopes_parameters;
scope_vars = ScopeVar.Map.empty; scope_vars = ScopeVar.Map.empty;
subscope_vars = SubScopeName.Map.empty; subscope_vars = SubScopeName.Map.empty;
toplevel_vars; toplevel_vars;
@ -1226,7 +1179,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
(* the resulting expression is the list of definitions of all the scopes, (* the resulting expression is the list of definitions of all the scopes,
ending with the top-level scope. The decl_ctx is filled in left-to-right ending with the top-level scope. The decl_ctx is filled in left-to-right
order, then the chained scopes aggregated from the right. *) order, then the chained scopes aggregated from the right. *)
let rec translate_defs ctx = function let rec translate_defs = function
| [] -> Bindlib.box Nil | [] -> Bindlib.box Nil
| def :: next -> | def :: next ->
let dvar, def = let dvar, def =
@ -1245,7 +1198,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
in in
let scope_var = let scope_var =
match match
(ScopeName.Map.find scope_name sctx.scope_sigs) (ScopeName.Map.find scope_name scopes_parameters)
.scope_sig_scope_ref .scope_sig_scope_ref
with with
| Local_scope_ref v -> v | Local_scope_ref v -> v
@ -1256,13 +1209,13 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
(fun body -> ScopeDef (scope_name, body)) (fun body -> ScopeDef (scope_name, body))
scope_body ) scope_body )
in in
let scope_next = translate_defs ctx next in let scope_next = translate_defs next in
let next_bind = Bindlib.bind_var dvar scope_next in let next_bind = Bindlib.bind_var dvar scope_next in
Bindlib.box_apply2 Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind)) (fun item next_bind -> Cons (item, next_bind))
def next_bind def next_bind
in in
let items = translate_defs top_ctx defs_ordering in let items = translate_defs defs_ordering in
Expr.Box.assert_closed items; Expr.Box.assert_closed items;
{ {
code_items = Bindlib.unbox items; code_items = Bindlib.unbox items;

View File

@ -228,12 +228,16 @@ type scope = {
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;
} }
type modul = {
module_scopes : scope ScopeName.Map.t;
module_topdefs : (expr option * typ) TopdefName.Map.t;
}
type program = { type program = {
program_module_name : ModuleName.t option; program_module_name : Ident.t Mark.pos option;
program_scopes : scope ScopeName.Map.t;
program_topdefs : (expr option * typ) TopdefName.Map.t;
program_ctx : decl_ctx; program_ctx : decl_ctx;
program_modules : program ModuleName.Map.t; program_modules : modul ModuleName.Map.t;
program_root : modul;
program_lang : Cli.backend_lang; program_lang : Cli.backend_lang;
} }
@ -299,8 +303,8 @@ let fold_exprs ~(f : 'a -> expr -> 'a) ~(init : 'a) (p : program) : 'a =
scope.scope_assertions acc scope.scope_assertions acc
in in
acc) acc)
p.program_scopes init p.program_root.module_scopes init
in in
TopdefName.Map.fold TopdefName.Map.fold
(fun _ (e, _) acc -> Option.fold ~none:acc ~some:(f acc) e) (fun _ (e, _) acc -> Option.fold ~none:acc ~some:(f acc) e)
p.program_topdefs acc p.program_root.module_topdefs acc

View File

@ -93,6 +93,7 @@ type io = {
type scope_def = { type scope_def = {
scope_def_rules : rule RuleName.Map.t; scope_def_rules : rule RuleName.Map.t;
(** empty outside of the root module *)
scope_def_typ : typ; scope_def_typ : typ;
scope_def_parameters : scope_def_parameters :
(Uid.MarkedString.info * Shared_ast.typ) list Mark.pos option; (Uid.MarkedString.info * Shared_ast.typ) list Mark.pos option;
@ -108,16 +109,22 @@ type scope = {
scope_uid : ScopeName.t; scope_uid : ScopeName.t;
scope_defs : scope_def ScopeDef.Map.t; scope_defs : scope_def ScopeDef.Map.t;
scope_assertions : assertion AssertionName.Map.t; scope_assertions : assertion AssertionName.Map.t;
(** empty outside of the root module *)
scope_options : catala_option Mark.pos list; scope_options : catala_option Mark.pos list;
scope_meta_assertions : meta_assertion list; scope_meta_assertions : meta_assertion list;
} }
type modul = {
module_scopes : scope ScopeName.Map.t;
module_topdefs : (expr option * typ) TopdefName.Map.t;
(** the expr is [None] outside of the root module *)
}
type program = { type program = {
program_module_name : ModuleName.t option; program_module_name : Ident.t Mark.pos option;
program_scopes : scope ScopeName.Map.t;
program_topdefs : (expr option * typ) TopdefName.Map.t;
program_ctx : decl_ctx; program_ctx : decl_ctx;
program_modules : program ModuleName.Map.t; program_modules : modul ModuleName.Map.t; (** Contains all submodules of the program, in a flattened structure *)
program_root : modul;
program_lang : Cli.backend_lang; program_lang : Cli.backend_lang;
} }

View File

@ -64,53 +64,45 @@ let scope ctx env scope =
let program prg = let program prg =
(* Caution: this environment building code is very similar to that in (* Caution: this environment building code is very similar to that in
scopelang/ast.ml. Any edits should probably be reflected. *) scopelang/ast.ml. Any edits should probably be reflected. *)
let base_typing_env prg = let env = Typing.Env.empty prg.program_ctx in
let env = Typing.Env.empty prg.program_ctx in let env =
let env = TopdefName.Map.fold
TopdefName.Map.fold (fun name ty env -> Typing.Env.add_toplevel_var name ty env)
(fun name (_e, ty) env -> Typing.Env.add_toplevel_var name ty env) prg.program_ctx.ctx_topdefs env
prg.program_topdefs env in
in let env =
let env = ScopeName.Map.fold
ScopeName.Map.fold (fun scope_name _info env ->
(fun scope_name scope env -> let modul =
let vars = List.fold_left
ScopeDef.Map.fold (fun _ m -> ModuleName.Map.find m prg.program_modules)
(fun var def vars -> prg.program_root (ScopeName.path scope_name)
in
let scope = ScopeName.Map.find scope_name modul.module_scopes in
let vars =
ScopeDef.Map.fold
(fun var def vars ->
match var with match var with
| Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars | Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars
| SubScopeVar _ -> vars) | SubScopeVar _ -> vars)
scope.scope_defs ScopeVar.Map.empty scope.scope_defs ScopeVar.Map.empty
in in
(* at this stage, rule resolution and the corresponding encapsulation (* at this stage, rule resolution and the corresponding encapsulation
into default terms hasn't taken place, so input and output into default terms hasn't taken place, so input and output
variables don't need different typing *) variables don't need different typing *)
Typing.Env.add_scope scope_name ~vars ~in_vars:vars env) Typing.Env.add_scope scope_name ~vars ~in_vars:vars env)
prg.program_scopes env prg.program_ctx.ctx_scopes env
in
env
in in
let rec build_typing_env prg = let module_topdefs =
ModuleName.Map.fold
(fun modname prg ->
Typing.Env.add_module modname ~module_env:(build_typing_env prg))
prg.program_modules (base_typing_env prg)
in
let env =
ModuleName.Map.fold
(fun modname prg ->
Typing.Env.add_module modname ~module_env:(build_typing_env prg))
prg.program_modules (base_typing_env prg)
in
let program_topdefs =
TopdefName.Map.map TopdefName.Map.map
(function (function
| Some e, ty -> | Some e, ty ->
Some (Expr.unbox (expr prg.program_ctx env (Expr.box e))), ty Some (Expr.unbox (expr prg.program_ctx env (Expr.box e))), ty
| None, ty -> None, ty) | None, ty -> None, ty)
prg.program_topdefs prg.program_root.module_topdefs
in in
let program_scopes = let module_scopes =
ScopeName.Map.map (scope prg.program_ctx env) prg.program_scopes ScopeName.Map.map (scope prg.program_ctx env)
prg.program_root.module_scopes
in in
{ prg with program_topdefs; program_scopes } { prg with program_root = { module_topdefs; module_scopes } }

View File

@ -123,7 +123,7 @@ let translate_unop (op : S.unop) pos : Ast.expr boxed =
let raise_error_cons_not_found let raise_error_cons_not_found
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(constructor : string Mark.pos) = (constructor : string Mark.pos) =
let constructors = Ident.Map.keys ctxt.constructor_idmap in let constructors = Ident.Map.keys ctxt.local.constructor_idmap in
let closest_constructors = let closest_constructors =
Suggestions.suggestion_minimum_levenshtein_distance_association constructors Suggestions.suggestion_minimum_levenshtein_distance_association constructors
(Mark.remove constructor) (Mark.remove constructor)
@ -146,7 +146,7 @@ let rec disambiguate_constructor
"The deep pattern matching syntactic sugar is not yet supported" "The deep pattern matching syntactic sugar is not yet supported"
in in
let possible_c_uids = let possible_c_uids =
try Ident.Map.find (Mark.remove constructor) ctxt.constructor_idmap try Ident.Map.find (Mark.remove constructor) ctxt.local.constructor_idmap
with Ident.Map.Not_found _ -> raise_error_cons_not_found ctxt constructor with Ident.Map.Not_found _ -> raise_error_cons_not_found ctxt constructor
in in
match path with match path with
@ -168,17 +168,13 @@ let rec disambiguate_constructor
with EnumName.Map.Not_found _ -> with EnumName.Map.Not_found _ ->
Message.raise_spanned_error pos "Enum %s does not contain case %s" Message.raise_spanned_error pos "Enum %s does not contain case %s"
(Mark.remove enum) (Mark.remove constructor)) (Mark.remove enum) (Mark.remove constructor))
| modname :: path -> ( | mod_id :: path ->
let modname = ModuleName.of_string modname in let constructor =
match ModuleName.Map.find_opt modname ctxt.modules with List.map (Mark.map (fun (_, c) -> path, c)) constructor0
| None -> in
Message.raise_spanned_error (ModuleName.pos modname) disambiguate_constructor
"Module \"%a\" not found" ModuleName.format modname (Name_resolution.get_module_ctx ctxt mod_id)
| Some ctxt -> constructor pos
let constructor =
List.map (Mark.map (fun (_, c) -> path, c)) constructor0
in
disambiguate_constructor ctxt constructor pos)
let int100 = Runtime.integer_of_int 100 let int100 = Runtime.integer_of_int 100
let rat100 = Runtime.decimal_of_integer int100 let rat100 = Runtime.decimal_of_integer int100
@ -370,7 +366,7 @@ let rec translate_expr
(* Note: allowing access to a global variable with the same name as a (* Note: allowing access to a global variable with the same name as a
subscope is disputable, but I see no good reason to forbid it either *) subscope is disputable, but I see no good reason to forbid it either *)
| None -> ( | None -> (
match Ident.Map.find_opt x ctxt.topdefs with match Ident.Map.find_opt x ctxt.local.topdefs with
| Some v -> | Some v ->
Expr.elocation Expr.elocation
(ToplevelVar { name = v, Mark.get (TopdefName.get_info v) }) (ToplevelVar { name = v, Mark.get (TopdefName.get_info v) })
@ -380,7 +376,7 @@ let rec translate_expr
"for a local, scope-wide or global variable" (x, pos)))) "for a local, scope-wide or global variable" (x, pos))))
| Ident (path, name) -> ( | Ident (path, name) -> (
let ctxt = Name_resolution.module_ctx ctxt path in let ctxt = Name_resolution.module_ctx ctxt path in
match Ident.Map.find_opt (Mark.remove name) ctxt.topdefs with match Ident.Map.find_opt (Mark.remove name) ctxt.local.topdefs with
| Some v -> | Some v ->
Expr.elocation Expr.elocation
(ToplevelVar { name = v, Mark.get (TopdefName.get_info v) }) (ToplevelVar { name = v, Mark.get (TopdefName.get_info v) })
@ -415,13 +411,8 @@ let rec translate_expr
let rec get_str ctxt = function let rec get_str ctxt = function
| [] -> None | [] -> None
| [c] -> Some (Name_resolution.get_struct ctxt c) | [c] -> Some (Name_resolution.get_struct ctxt c)
| modname :: path -> ( | mod_id :: path ->
let modname = ModuleName.of_string modname in get_str (Name_resolution.get_module_ctx ctxt mod_id) path
match ModuleName.Map.find_opt modname ctxt.modules with
| None ->
Message.raise_spanned_error (ModuleName.pos modname)
"Module \"%a\" not found" ModuleName.format modname
| Some ctxt -> get_str ctxt path)
in in
Expr.edstructaccess ~e ~field:(Mark.remove x) Expr.edstructaccess ~e ~field:(Mark.remove x)
~name_opt:(get_str ctxt path) emark) ~name_opt:(get_str ctxt path) emark)
@ -478,7 +469,7 @@ let rec translate_expr
| StructLit (((path, s_name), _), fields) -> | StructLit (((path, s_name), _), fields) ->
let ctxt = Name_resolution.module_ctx ctxt path in let ctxt = Name_resolution.module_ctx ctxt path in
let s_uid = let s_uid =
match Ident.Map.find_opt (Mark.remove s_name) ctxt.typedefs with match Ident.Map.find_opt (Mark.remove s_name) ctxt.local.typedefs with
| Some (Name_resolution.TStruct s_uid) -> s_uid | Some (Name_resolution.TStruct s_uid) -> s_uid
| _ -> | _ ->
Message.raise_spanned_error (Mark.get s_name) Message.raise_spanned_error (Mark.get s_name)
@ -490,7 +481,7 @@ let rec translate_expr
let f_uid = let f_uid =
try try
StructName.Map.find s_uid StructName.Map.find s_uid
(Ident.Map.find (Mark.remove f_name) ctxt.field_idmap) (Ident.Map.find (Mark.remove f_name) ctxt.local.field_idmap)
with StructName.Map.Not_found _ | Ident.Map.Not_found _ -> with StructName.Map.Not_found _ | Ident.Map.Not_found _ ->
Message.raise_spanned_error (Mark.get f_name) Message.raise_spanned_error (Mark.get f_name)
"This identifier should refer to a field of struct %s" "This identifier should refer to a field of struct %s"
@ -518,7 +509,7 @@ let rec translate_expr
Expr.estruct ~name:s_uid ~fields:s_fields emark Expr.estruct ~name:s_uid ~fields:s_fields emark
| EnumInject (((path, (constructor, pos_constructor)), _), payload) -> ( | EnumInject (((path, (constructor, pos_constructor)), _), payload) -> (
let get_possible_c_uids ctxt = let get_possible_c_uids ctxt =
try Ident.Map.find constructor ctxt.Name_resolution.constructor_idmap try Ident.Map.find constructor ctxt.Name_resolution.local.constructor_idmap
with Ident.Map.Not_found _ -> with Ident.Map.Not_found _ ->
raise_error_cons_not_found ctxt (constructor, pos_constructor) raise_error_cons_not_found ctxt (constructor, pos_constructor)
in in
@ -1027,7 +1018,7 @@ let process_def
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(prgm : Ast.program) (prgm : Ast.program)
(def : S.definition) : Ast.program = (def : S.definition) : Ast.program =
let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_root.module_scopes in
let scope_ctxt = ScopeName.Map.find scope_uid ctxt.scopes in let scope_ctxt = ScopeName.Map.find scope_uid ctxt.scopes in
let def_key = let def_key =
Name_resolution.get_def_key Name_resolution.get_def_key
@ -1091,10 +1082,13 @@ let process_def
scope_defs = Ast.ScopeDef.Map.add def_key scope_def scope.scope_defs; scope_defs = Ast.ScopeDef.Map.add def_key scope_def scope.scope_defs;
} }
in in
let module_scopes =
ScopeName.Map.add scope_uid scope_updated
prgm.program_root.module_scopes
in
{ {
prgm with prgm with
program_scopes = program_root = { prgm.program_root with module_scopes }
ScopeName.Map.add scope_uid scope_updated prgm.program_scopes;
} }
(** Translates a {!type: S.rule} from the surface language *) (** Translates a {!type: S.rule} from the surface language *)
@ -1114,7 +1108,7 @@ let process_assert
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(prgm : Ast.program) (prgm : Ast.program)
(ass : S.assertion) : Ast.program = (ass : S.assertion) : Ast.program =
let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_root.module_scopes in
let ass = let ass =
translate_expr (Some scope_uid) None ctxt Ident.Map.empty translate_expr (Some scope_uid) None ctxt Ident.Map.empty
(match ass.S.assertion_condition with (match ass.S.assertion_condition with
@ -1146,9 +1140,11 @@ let process_assert
scope.scope_assertions; scope.scope_assertions;
} }
in in
let module_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_root.module_scopes
in
{ {
prgm with prgm with
program_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_scopes; program_root = { prgm.program_root with module_scopes }
} }
(** Translates a surface definition, rule or assertion *) (** Translates a surface definition, rule or assertion *)
@ -1167,7 +1163,7 @@ let process_scope_use_item
| S.Assertion ass -> process_assert precond scope ctxt prgm ass | S.Assertion ass -> process_assert precond scope ctxt prgm ass
| S.DateRounding (r, _) -> | S.DateRounding (r, _) ->
let scope_uid = scope in let scope_uid = scope in
let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_root.module_scopes in
let r = let r =
match r with match r with
| S.Increasing -> Ast.Increasing | S.Increasing -> Ast.Increasing
@ -1192,9 +1188,10 @@ let process_scope_use_item
Mark.copy item (Ast.DateRounding r) :: scope.scope_options; Mark.copy item (Ast.DateRounding r) :: scope.scope_options;
} }
in in
let module_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_root.module_scopes in
{ {
prgm with prgm with
program_scopes = ScopeName.Map.add scope_uid new_scope prgm.program_scopes; program_root = { prgm.program_root with module_scopes }
} }
| _ -> prgm | _ -> prgm
@ -1254,7 +1251,7 @@ let process_scope_use
let scope_uid = Name_resolution.get_scope ctxt use.scope_use_name in let scope_uid = Name_resolution.get_scope ctxt use.scope_use_name in
(* Make sure the scope exists *) (* Make sure the scope exists *)
let prgm = let prgm =
match ScopeName.Map.find_opt scope_uid prgm.program_scopes with match ScopeName.Map.find_opt scope_uid prgm.program_root.module_scopes with
| Some _ -> prgm | Some _ -> prgm
| None -> assert false | None -> assert false
(* should not happen *) (* should not happen *)
@ -1270,7 +1267,7 @@ let process_topdef
(prgm : Ast.program) (prgm : Ast.program)
(def : S.top_def) : Ast.program = (def : S.top_def) : Ast.program =
let id = let id =
Ident.Map.find (Mark.remove def.S.topdef_name) ctxt.Name_resolution.topdefs Ident.Map.find (Mark.remove def.S.topdef_name) ctxt.Name_resolution.local.topdefs
in in
let translate_typ t = Name_resolution.process_type ctxt t in let translate_typ t = Name_resolution.process_type ctxt t in
let translate_tbase (tbase, m) = translate_typ (Base tbase, m) in let translate_tbase (tbase, m) = translate_typ (Base tbase, m) in
@ -1300,7 +1297,7 @@ let process_topdef
in in
Some (Expr.unbox_closed e) Some (Expr.unbox_closed e)
in in
let program_topdefs = let module_topdefs =
TopdefName.Map.update id TopdefName.Map.update id
(fun def0 -> (fun def0 ->
match def0, expr_opt with match def0, expr_opt with
@ -1318,9 +1315,9 @@ let process_topdef
| Some _, Some _ -> err "Multiple definitions" | Some _, Some _ -> err "Multiple definitions"
| Some e, None -> Some (Some e, typ) | Some e, None -> Some (Some e, typ)
| None, Some e -> Some (Some e, ty0))) | None, Some e -> Some (Some e, ty0)))
prgm.Ast.program_topdefs prgm.Ast.program_root.module_topdefs
in in
{ prgm with Ast.program_topdefs } { prgm with program_root = { prgm.program_root with module_topdefs } }
let attribute_to_io (attr : S.scope_decl_context_io) : Ast.io = let attribute_to_io (attr : S.scope_decl_context_io) : Ast.io =
{ {
@ -1337,13 +1334,13 @@ let attribute_to_io (attr : S.scope_decl_context_io) : Ast.io =
let init_scope_defs let init_scope_defs
(ctxt : Name_resolution.context) (ctxt : Name_resolution.context)
(scope_idmap : Name_resolution.scope_var_or_subscope Ident.Map.t) : (scope_idmap : scope_var_or_subscope Ident.Map.t) :
Ast.scope_def Ast.ScopeDef.Map.t = Ast.scope_def Ast.ScopeDef.Map.t =
(* Initializing the definitions of all scopes and subscope vars, with no rules (* Initializing the definitions of all scopes and subscope vars, with no rules
yet inside *) yet inside *)
let add_def _ v scope_def_map = let add_def _ v scope_def_map =
match v with match v with
| Name_resolution.ScopeVar v -> ( | ScopeVar v -> (
let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in
match v_sig.var_sig_states_list with match v_sig.var_sig_states_list with
| [] -> | [] ->
@ -1389,19 +1386,20 @@ let init_scope_defs
(scope_def_map, 0) states (scope_def_map, 0) states
in in
scope_def) scope_def)
| Name_resolution.SubScope (v0, subscope_uid) -> | SubScope (v0, subscope_uid) ->
let sub_scope_def = Name_resolution.get_scope_context ctxt subscope_uid in let sub_scope_def = Name_resolution.get_scope_context ctxt subscope_uid in
let ctxt = let ctxt =
List.fold_left List.fold_left
(fun ctx m -> ModuleName.Map.find m ctx.Name_resolution.modules) (fun ctx m ->
{ ctxt with local = ModuleName.Map.find m ctx.Name_resolution.modules })
ctxt ctxt
(ScopeName.path subscope_uid) (ScopeName.path subscope_uid)
in in
Ident.Map.fold Ident.Map.fold
(fun _ v scope_def_map -> (fun _ v scope_def_map ->
match v with match v with
| Name_resolution.SubScope _ -> scope_def_map | SubScope _ -> scope_def_map
| Name_resolution.ScopeVar v -> | ScopeVar v ->
(* TODO: shouldn't we ignore internal variables too at this point (* TODO: shouldn't we ignore internal variables too at this point
? *) ? *)
let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in
@ -1424,91 +1422,110 @@ let init_scope_defs
(** Main function of this module *) (** Main function of this module *)
let translate_program (ctxt : Name_resolution.context) (surface : S.program) : let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
Ast.program = Ast.program =
let top_ctx = ctxt in let get_scope s_uid =
let desugared = let s_context = ScopeName.Map.find s_uid ctxt.scopes in
let get_program_scopes ctxt = let scope_vars =
ScopeName.Map.mapi Ident.Map.fold
(fun s_uid s_context -> (fun _ v acc ->
let scope_vars = match v with
Ident.Map.fold | SubScope _ -> acc
(fun _ v acc -> | ScopeVar v -> (
match v with let v_sig =
| Name_resolution.SubScope _ -> acc ScopeVar.Map.find v ctxt.Name_resolution.var_typs
| Name_resolution.ScopeVar v -> ( in
let v_sig = match v_sig.Name_resolution.var_sig_states_list with
ScopeVar.Map.find v ctxt.Name_resolution.var_typs | [] -> ScopeVar.Map.add v Ast.WholeVar acc
in | states -> ScopeVar.Map.add v (Ast.States states) acc))
match v_sig.Name_resolution.var_sig_states_list with s_context.Name_resolution.var_idmap ScopeVar.Map.empty
| [] -> ScopeVar.Map.add v Ast.WholeVar acc
| states -> ScopeVar.Map.add v (Ast.States states) acc))
s_context.Name_resolution.var_idmap ScopeVar.Map.empty
in
let scope_sub_scopes =
Ident.Map.fold
(fun _ v acc ->
match v with
| Name_resolution.ScopeVar _ -> acc
| Name_resolution.SubScope (sub_var, sub_scope) ->
SubScopeName.Map.add sub_var sub_scope acc)
s_context.Name_resolution.var_idmap SubScopeName.Map.empty
in
{
Ast.scope_vars;
scope_sub_scopes;
scope_defs = init_scope_defs top_ctx s_context.var_idmap;
scope_assertions = Ast.AssertionName.Map.empty;
scope_meta_assertions = [];
scope_options = [];
scope_uid = s_uid;
})
ctxt.Name_resolution.scopes
in in
let rec make_ctx ctxt = let scope_sub_scopes =
let submodules = Ident.Map.fold
ModuleName.Map.map make_ctx ctxt.Name_resolution.modules (fun _ v acc ->
match v with
| ScopeVar _ -> acc
| SubScope (sub_var, sub_scope) ->
SubScopeName.Map.add sub_var sub_scope acc)
s_context.Name_resolution.var_idmap SubScopeName.Map.empty
in
{
Ast.scope_vars;
scope_sub_scopes;
scope_defs = init_scope_defs ctxt s_context.var_idmap;
scope_assertions = Ast.AssertionName.Map.empty;
scope_meta_assertions = [];
scope_options = [];
scope_uid = s_uid;
}
in
let get_scopes mctx =
Ident.Map.fold (fun _ tydef acc -> match tydef with
| Name_resolution.TScope (s_uid, _) ->
ScopeName.Map.add s_uid (get_scope s_uid) acc
| _ -> acc)
mctx.Name_resolution.typedefs ScopeName.Map.empty;
in
let program_modules =
ModuleName.Map.map (fun mctx ->
{ Ast.module_scopes = get_scopes mctx;
Ast.module_topdefs =
Ident.Map.fold (fun _ name acc ->
TopdefName.Map.add name
(None,
TopdefName.Map.find name ctxt.Name_resolution.topdef_types)
acc;
)
mctx.topdefs TopdefName.Map.empty
})
ctxt.modules
in
let program_ctx =
let open Name_resolution in
let ctx_scopes mctx acc =
Ident.Map.fold (fun _ tydef acc ->
match tydef with
| TScope (s_uid, info) ->
ScopeName.Map.add s_uid info acc
| _ -> acc)
mctx.Name_resolution.typedefs acc
in
let ctx_modules =
let rec aux mctx =
Ident.Map.fold (fun _ m (M acc) ->
let sub = aux (ModuleName.Map.find m ctxt.modules) in
M (ModuleName.Map.add m sub acc))
mctx.used_modules (M ModuleName.Map.empty)
in in
{ aux ctxt.local
Ast.program_lang = surface.program_lang;
Ast.program_module_name =
Option.map ModuleName.of_string
surface.Surface.Ast.program_module_name;
Ast.program_ctx =
{
(* After name resolution, type definitions (structs and enums) are
exposed at toplevel for easier lookup *)
ctx_structs =
ModuleName.Map.fold
(fun _ prg acc ->
StructName.Map.union
(fun _ _ _ -> assert false)
acc prg.Ast.program_ctx.ctx_structs)
submodules ctxt.Name_resolution.structs;
ctx_enums =
ModuleName.Map.fold
(fun _ prg acc ->
EnumName.Map.union
(fun _ _ _ -> assert false)
acc prg.Ast.program_ctx.ctx_enums)
submodules ctxt.Name_resolution.enums;
ctx_scopes =
Ident.Map.fold
(fun _ def acc ->
match def with
| Name_resolution.TScope (scope, scope_info) ->
ScopeName.Map.add scope scope_info acc
| _ -> acc)
ctxt.Name_resolution.typedefs ScopeName.Map.empty;
ctx_struct_fields = ctxt.Name_resolution.field_idmap;
ctx_topdefs = ctxt.Name_resolution.topdef_types;
ctx_modules =
ModuleName.Map.map (fun s -> s.Ast.program_ctx) submodules;
};
Ast.program_topdefs = TopdefName.Map.empty;
Ast.program_scopes = get_program_scopes ctxt;
Ast.program_modules = submodules;
}
in in
make_ctx ctxt {
ctx_structs = ctxt.structs;
ctx_enums = ctxt.enums;
ctx_scopes =
ModuleName.Map.fold (fun _ -> ctx_scopes)
ctxt.modules
(ctx_scopes ctxt.local ScopeName.Map.empty);
ctx_topdefs = ctxt.topdef_types;
ctx_struct_fields = ctxt.local.field_idmap;
ctx_enum_constrs = ctxt.local.constructor_idmap;
ctx_scope_index =
Ident.Map.filter_map (fun _ -> function
| Name_resolution.TScope (s, _) -> Some s
| _ -> None)
ctxt.local.typedefs;
ctx_modules;
}
in
let desugared =
{
Ast.program_lang = surface.program_lang;
Ast.program_module_name = surface.Surface.Ast.program_module_name;
Ast.program_modules;
Ast.program_ctx;
Ast.program_root = {
Ast.module_scopes = get_scopes ctxt.Name_resolution.local;
Ast.module_topdefs = TopdefName.Map.empty;
};
}
in in
let process_code_block ctxt prgm block = let process_code_block ctxt prgm block =
List.fold_left List.fold_left
@ -1527,29 +1544,6 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) :
(fun prgm child -> process_structure prgm child) (fun prgm child -> process_structure prgm child)
prgm children prgm children
| S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block | S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block
| S.ModuleDef ((name, pos) as mname) -> | S.ModuleDef _ | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm
let file = Filename.basename (Pos.get_file pos) in
if not File.(equal name (Filename.remove_extension file)) then
Message.raise_spanned_error pos
"Module declared as %a, which does not match the file name %a"
ModuleName.format
(ModuleName.of_string mname)
File.format file
else prgm
| S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm
in
let desugared =
List.fold_left
(fun acc (id, intf) ->
let id = ModuleName.of_string id in
let modul = ModuleName.Map.find id acc.Ast.program_modules in
let modul =
process_code_block (ModuleName.Map.find id ctxt.modules) modul intf
in
{
acc with
program_modules = ModuleName.Map.add id modul acc.program_modules;
})
desugared surface.S.program_modules
in in
List.fold_left process_structure desugared surface.S.program_items List.fold_left process_structure desugared surface.S.program_items

View File

@ -39,7 +39,7 @@ let detect_empty_definitions (p : program) : unit =
defined; did you forget something?" defined; did you forget something?"
ScopeName.format scope_name Ast.ScopeDef.format scope_def_key) ScopeName.format scope_name Ast.ScopeDef.format scope_def_key)
scope.scope_defs) scope.scope_defs)
p.program_scopes p.program_root.module_scopes
(* To detect rules that have the same justification and conclusion, we create a (* To detect rules that have the same justification and conclusion, we create a
set data structure with an appropriate comparison function *) set data structure with an appropriate comparison function *)
@ -97,7 +97,7 @@ let detect_identical_rules (p : program) : unit =
else "definitions")) else "definitions"))
rules_seen) rules_seen)
scope.scope_defs) scope.scope_defs)
p.program_scopes p.program_root.module_scopes
let detect_unused_struct_fields (p : program) : unit = let detect_unused_struct_fields (p : program) : unit =
(* TODO: this analysis should be finer grained: a false negative is if the (* TODO: this analysis should be finer grained: a false negative is if the
@ -111,14 +111,9 @@ let detect_unused_struct_fields (p : program) : unit =
~f:(fun struct_fields_used e -> ~f:(fun struct_fields_used e ->
let rec structs_fields_used_expr e struct_fields_used = let rec structs_fields_used_expr e struct_fields_used =
match Mark.remove e with match Mark.remove e with
| EDStructAccess { name_opt = Some name; e = e_struct; field } -> | EDStructAccess _ -> assert false
let ctx = (* linting must be performed after disambiguation *)
Program.module_ctx p.program_ctx (StructName.path name) | EStructAccess { e = e_struct; field; _ } ->
in
let field =
StructName.Map.find name
(Ident.Map.find field ctx.ctx_struct_fields)
in
StructField.Set.add field StructField.Set.add field
(structs_fields_used_expr e_struct struct_fields_used) (structs_fields_used_expr e_struct struct_fields_used)
| EStruct { name = _; fields } -> | EStruct { name = _; fields } ->
@ -284,7 +279,7 @@ let detect_dead_code (p : program) : unit =
emit_unused_warning ()) emit_unused_warning ())
states) states)
scope.scope_vars) scope.scope_vars)
p.program_scopes p.program_root.module_scopes
let lint_program (p : program) : unit = let lint_program (p : program) : unit =
detect_empty_definitions p; detect_empty_definitions p;

View File

@ -30,10 +30,6 @@ type scope_def_context = {
label_idmap : LabelName.t Ident.Map.t; label_idmap : LabelName.t Ident.Map.t;
} }
type scope_var_or_subscope =
| ScopeVar of ScopeVar.t
| SubScope of SubScopeName.t * ScopeName.t
type scope_context = { type scope_context = {
var_idmap : scope_var_or_subscope Ident.Map.t; var_idmap : scope_var_or_subscope Ident.Map.t;
(** All variables, including scope variables and subscopes *) (** All variables, including scope variables and subscopes *)
@ -67,7 +63,7 @@ type typedef =
| TEnum of EnumName.t | TEnum of EnumName.t
| TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *)
type context = { type module_context = {
path : Uid.Path.t; path : Uid.Path.t;
typedefs : typedef Ident.Map.t; typedefs : typedef Ident.Map.t;
(** Gathers the names of the scopes, structs and enums *) (** Gathers the names of the scopes, structs and enums *)
@ -77,17 +73,24 @@ type context = {
constructor_idmap : EnumConstructor.t EnumName.Map.t Ident.Map.t; constructor_idmap : EnumConstructor.t EnumName.Map.t Ident.Map.t;
(** The names of the enum constructors. Constructor names can be shared (** The names of the enum constructors. Constructor names can be shared
between different enums *) between different enums *)
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *)
used_modules : ModuleName.t Ident.Map.t;
}
(** Context for name resolution, valid within a given module *)
type context = {
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdef_types : typ TopdefName.Map.t; topdef_types : typ TopdefName.Map.t;
structs : struct_context StructName.Map.t; structs : struct_context StructName.Map.t;
(** For each struct, its context *) (** For each struct, its context *)
enums : enum_context EnumName.Map.t; (** For each enum, its context *) enums : enum_context EnumName.Map.t; (** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t; var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
modules : context ModuleName.Map.t; modules : module_context ModuleName.Map.t;
local : module_context;
(** Module being currently analysed (at the end: the root module) *)
} }
(** Main context used throughout {!module: Surface.Desugaring} *) (** Global context used throughout {!module: Surface.Desugaring} *)
(** {1 Helpers} *) (** {1 Helpers} *)
@ -114,16 +117,6 @@ let get_var_io (ctxt : context) (uid : ScopeVar.t) :
(ScopeVar.Map.find uid ctxt.var_typs).var_sig_io (ScopeVar.Map.find uid ctxt.var_typs).var_sig_io
let get_scope_context (ctxt : context) (scope : ScopeName.t) : scope_context = let get_scope_context (ctxt : context) (scope : ScopeName.t) : scope_context =
let rec remove_common_prefix curpath scpath =
match curpath, scpath with
| m1 :: cp, m2 :: sp when ModuleName.equal m1 m2 ->
remove_common_prefix cp sp
| _ -> scpath
in
let path = remove_common_prefix ctxt.path (ScopeName.path scope) in
let ctxt =
List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.modules) ctxt path
in
ScopeName.Map.find scope ctxt.scopes ScopeName.Map.find scope ctxt.scopes
(** Get the variable uid inside the scope given in argument *) (** Get the variable uid inside the scope given in argument *)
@ -196,7 +189,7 @@ let is_def_cond (ctxt : context) (def : Ast.ScopeDef.t) : bool =
is_var_cond ctxt x is_var_cond ctxt x
let get_enum ctxt id = let get_enum ctxt id =
match Ident.Map.find (Mark.remove id) ctxt.typedefs with match Ident.Map.find (Mark.remove id) ctxt.local.typedefs with
| TEnum id -> id | TEnum id -> id
| TStruct sid -> | TStruct sid ->
Message.raise_multispanned_error Message.raise_multispanned_error
@ -217,7 +210,7 @@ let get_enum ctxt id =
(Mark.remove id) (Mark.remove id)
let get_struct ctxt id = let get_struct ctxt id =
match Ident.Map.find (Mark.remove id) ctxt.typedefs with match Ident.Map.find (Mark.remove id) ctxt.local.typedefs with
| TStruct id | TScope (_, { out_struct_name = id; _ }) -> id | TStruct id | TScope (_, { out_struct_name = id; _ }) -> id
| TEnum eid -> | TEnum eid ->
Message.raise_multispanned_error Message.raise_multispanned_error
@ -231,7 +224,7 @@ let get_struct ctxt id =
(Mark.remove id) (Mark.remove id)
let get_scope ctxt id = let get_scope ctxt id =
match Ident.Map.find (Mark.remove id) ctxt.typedefs with match Ident.Map.find (Mark.remove id) ctxt.local.typedefs with
| TScope (id, _) -> id | TScope (id, _) -> id
| TEnum eid -> | TEnum eid ->
Message.raise_multispanned_error Message.raise_multispanned_error
@ -251,16 +244,21 @@ let get_scope ctxt id =
Message.raise_spanned_error (Mark.get id) "No scope named %s found" Message.raise_spanned_error (Mark.get id) "No scope named %s found"
(Mark.remove id) (Mark.remove id)
let rec module_ctx ctxt path = let get_modname ctxt (id, pos) =
match path with match Ident.Map.find_opt id ctxt.local.used_modules with
| None ->
Message.raise_spanned_error pos "Module \"@{<blue>%s@}\" not found" id
| Some modname -> modname
let get_module_ctx ctxt id =
let modname = get_modname ctxt id in
{ ctxt with local = ModuleName.Map.find modname ctxt.modules }
let rec module_ctx ctxt path0 =
match path0 with
| [] -> ctxt | [] -> ctxt
| modname :: path -> ( | mod_id :: path ->
let modname = ModuleName.of_string modname in module_ctx (get_module_ctx ctxt mod_id) path
match ModuleName.Map.find_opt modname ctxt.modules with
| None ->
Message.raise_spanned_error (ModuleName.pos modname)
"Module \"%a\" not found" ModuleName.format modname
| Some ctxt -> module_ctx ctxt path)
(** {1 Declarations pass} *) (** {1 Declarations pass} *)
@ -328,7 +326,7 @@ let rec process_base_typ
| Surface.Ast.Boolean -> TLit TBool, typ_pos | Surface.Ast.Boolean -> TLit TBool, typ_pos
| Surface.Ast.Text -> raise_unsupported_feature "text type" typ_pos | Surface.Ast.Text -> raise_unsupported_feature "text type" typ_pos
| Surface.Ast.Named ([], (ident, _pos)) -> ( | Surface.Ast.Named ([], (ident, _pos)) -> (
match Ident.Map.find_opt ident ctxt.typedefs with match Ident.Map.find_opt ident ctxt.local.typedefs with
| Some (TStruct s_uid) -> TStruct s_uid, typ_pos | Some (TStruct s_uid) -> TStruct s_uid, typ_pos
| Some (TEnum e_uid) -> TEnum e_uid, typ_pos | Some (TEnum e_uid) -> TEnum e_uid, typ_pos
| Some (TScope (_, scope_str)) -> | Some (TScope (_, scope_str)) ->
@ -338,15 +336,14 @@ let rec process_base_typ
"Unknown type @{<yellow>\"%s\"@}, not a struct or enum previously \ "Unknown type @{<yellow>\"%s\"@}, not a struct or enum previously \
declared" declared"
ident) ident)
| Surface.Ast.Named (modul :: path, id) -> ( | Surface.Ast.Named ((modul, mpos) :: path, id) -> (
let modul = ModuleName.of_string modul in match Ident.Map.find_opt modul ctxt.local.used_modules with
match ModuleName.Map.find_opt modul ctxt.modules with
| None -> | None ->
Message.raise_spanned_error (ModuleName.pos modul) Message.raise_spanned_error mpos
"This refers to module %a, which was not found" ModuleName.format "This refers to module @{<blue>%s@}, which was not found" modul
modul | Some mname ->
| Some mod_ctxt -> let mod_ctxt = ModuleName.Map.find mname ctxt.modules in
process_base_typ mod_ctxt process_base_typ { ctxt with local = mod_ctxt }
Surface.Ast.(Data (Primitive (Named (path, id))), typ_pos))) Surface.Ast.(Data (Primitive (Named (path, id))), typ_pos)))
(** Process a type (function or not) *) (** Process a type (function or not) *)
@ -449,9 +446,9 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) :
List.fold_left List.fold_left
(fun ctxt (fdecl, _) -> (fun ctxt (fdecl, _) ->
let f_uid = StructField.fresh fdecl.Surface.Ast.struct_decl_field_name in let f_uid = StructField.fresh fdecl.Surface.Ast.struct_decl_field_name in
let ctxt = let local =
{ {
ctxt with ctxt.local with
field_idmap = field_idmap =
Ident.Map.update Ident.Map.update
(Mark.remove fdecl.Surface.Ast.struct_decl_field_name) (Mark.remove fdecl.Surface.Ast.struct_decl_field_name)
@ -459,26 +456,26 @@ let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) :
match uids with match uids with
| None -> Some (StructName.Map.singleton s_uid f_uid) | None -> Some (StructName.Map.singleton s_uid f_uid)
| Some uids -> Some (StructName.Map.add s_uid f_uid uids)) | Some uids -> Some (StructName.Map.add s_uid f_uid uids))
ctxt.field_idmap; ctxt.local.field_idmap;
} }
in in
{ let ctxt = { ctxt with local } in
ctxt with let structs =
structs = StructName.Map.update s_uid
StructName.Map.update s_uid (fun fields ->
(fun fields -> match fields with
match fields with | None ->
| None -> Some
Some (StructField.Map.singleton f_uid
(StructField.Map.singleton f_uid (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ))
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)) | Some fields ->
| Some fields -> Some
Some (StructField.Map.add f_uid
(StructField.Map.add f_uid (process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ)
(process_type ctxt fdecl.Surface.Ast.struct_decl_field_typ) fields))
fields)) ctxt.structs
ctxt.structs; in
}) { ctxt with structs })
ctxt sdecl.struct_decl_fields ctxt sdecl.struct_decl_fields
(** Process an enum declaration *) (** Process an enum declaration *)
@ -494,9 +491,9 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context
List.fold_left List.fold_left
(fun ctxt (cdecl, cdecl_pos) -> (fun ctxt (cdecl, cdecl_pos) ->
let c_uid = EnumConstructor.fresh cdecl.Surface.Ast.enum_decl_case_name in let c_uid = EnumConstructor.fresh cdecl.Surface.Ast.enum_decl_case_name in
let ctxt = let local =
{ {
ctxt with ctxt.local with
constructor_idmap = constructor_idmap =
Ident.Map.update Ident.Map.update
(Mark.remove cdecl.Surface.Ast.enum_decl_case_name) (Mark.remove cdecl.Surface.Ast.enum_decl_case_name)
@ -504,29 +501,29 @@ let process_enum_decl (ctxt : context) (edecl : Surface.Ast.enum_decl) : context
match uids with match uids with
| None -> Some (EnumName.Map.singleton e_uid c_uid) | None -> Some (EnumName.Map.singleton e_uid c_uid)
| Some uids -> Some (EnumName.Map.add e_uid c_uid uids)) | Some uids -> Some (EnumName.Map.add e_uid c_uid uids))
ctxt.constructor_idmap; ctxt.local.constructor_idmap;
} }
in in
{ let ctxt = { ctxt with local } in
ctxt with let enums =
enums = EnumName.Map.update e_uid
EnumName.Map.update e_uid (fun cases ->
(fun cases -> let typ =
let typ = match cdecl.Surface.Ast.enum_decl_case_typ with
match cdecl.Surface.Ast.enum_decl_case_typ with | None -> TLit TUnit, cdecl_pos
| None -> TLit TUnit, cdecl_pos | Some typ -> process_type ctxt typ
| Some typ -> process_type ctxt typ in
in match cases with
match cases with | None -> Some (EnumConstructor.Map.singleton c_uid typ)
| None -> Some (EnumConstructor.Map.singleton c_uid typ) | Some fields -> Some (EnumConstructor.Map.add c_uid typ fields))
| Some fields -> Some (EnumConstructor.Map.add c_uid typ fields)) ctxt.enums
ctxt.enums; in
}) { ctxt with enums })
ctxt edecl.enum_decl_cases ctxt edecl.enum_decl_cases
let process_topdef ctxt def = let process_topdef ctxt def =
let uid = let uid =
Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.topdefs Ident.Map.find (Mark.remove def.Surface.Ast.topdef_name) ctxt.local.topdefs
in in
{ {
ctxt with ctxt with
@ -605,7 +602,7 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) :
| ScopeVar v -> ( | ScopeVar v -> (
try try
let field = let field =
StructName.Map.find str (Ident.Map.find id ctxt.field_idmap) StructName.Map.find str (Ident.Map.find id ctxt.local.field_idmap)
in in
ScopeVar.Map.add v field svmap ScopeVar.Map.add v field svmap
with StructName.Map.Not_found _ | Ident.Map.Not_found _ -> svmap)) with StructName.Map.Not_found _ | Ident.Map.Not_found _ -> svmap))
@ -620,9 +617,9 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) :
(TScope (TScope
(scope, { in_struct_name; out_struct_name; out_struct_fields })) (scope, { in_struct_name; out_struct_name; out_struct_fields }))
| _ -> assert false) | _ -> assert false)
ctxt.typedefs ctxt.local.typedefs
in in
{ ctxt with typedefs } { ctxt with local = { ctxt.local with typedefs } }
let typedef_info = function let typedef_info = function
| TStruct t -> StructName.get_info t | TStruct t -> StructName.get_info t
@ -648,59 +645,61 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
Option.iter Option.iter
(fun use -> (fun use ->
raise_already_defined_error (typedef_info use) name pos "scope") raise_already_defined_error (typedef_info use) name pos "scope")
(Ident.Map.find_opt name ctxt.typedefs); (Ident.Map.find_opt name ctxt.local.typedefs);
let scope_uid = ScopeName.fresh ctxt.path (name, pos) in let scope_uid = ScopeName.fresh ctxt.local.path (name, pos) in
let in_struct_name = StructName.fresh ctxt.path (name ^ "_in", pos) in let in_struct_name = StructName.fresh ctxt.local.path (name ^ "_in", pos) in
let out_struct_name = StructName.fresh ctxt.path (name, pos) in let out_struct_name = StructName.fresh ctxt.local.path (name, pos) in
let typedefs =
Ident.Map.add name
(TScope
( scope_uid,
{
in_struct_name;
out_struct_name;
out_struct_fields = ScopeVar.Map.empty;
} ))
ctxt.local.typedefs
in
let scopes =
ScopeName.Map.add scope_uid
{
var_idmap = Ident.Map.empty;
scope_defs_contexts = Ast.ScopeDef.Map.empty;
sub_scopes = ScopeName.Set.empty;
}
ctxt.scopes
in
{ {
ctxt with ctxt with
typedefs = local = { ctxt.local with typedefs };
Ident.Map.add name scopes;
(TScope
( scope_uid,
{
in_struct_name;
out_struct_name;
out_struct_fields = ScopeVar.Map.empty;
} ))
ctxt.typedefs;
scopes =
ScopeName.Map.add scope_uid
{
var_idmap = Ident.Map.empty;
scope_defs_contexts = Ast.ScopeDef.Map.empty;
sub_scopes = ScopeName.Set.empty;
}
ctxt.scopes;
} }
| StructDecl sdecl -> | StructDecl sdecl ->
let name, pos = sdecl.struct_decl_name in let name, pos = sdecl.struct_decl_name in
Option.iter Option.iter
(fun use -> (fun use ->
raise_already_defined_error (typedef_info use) name pos "struct") raise_already_defined_error (typedef_info use) name pos "struct")
(Ident.Map.find_opt name ctxt.typedefs); (Ident.Map.find_opt name ctxt.local.typedefs);
let s_uid = StructName.fresh ctxt.path sdecl.struct_decl_name in let s_uid = StructName.fresh ctxt.local.path sdecl.struct_decl_name in
{ let typedefs =
ctxt with Ident.Map.add
typedefs = (Mark.remove sdecl.struct_decl_name)
Ident.Map.add (TStruct s_uid) ctxt.local.typedefs;
(Mark.remove sdecl.struct_decl_name) in
(TStruct s_uid) ctxt.typedefs; { ctxt with local = { ctxt.local with typedefs} }
}
| EnumDecl edecl -> | EnumDecl edecl ->
let name, pos = edecl.enum_decl_name in let name, pos = edecl.enum_decl_name in
Option.iter Option.iter
(fun use -> (fun use ->
raise_already_defined_error (typedef_info use) name pos "enum") raise_already_defined_error (typedef_info use) name pos "enum")
(Ident.Map.find_opt name ctxt.typedefs); (Ident.Map.find_opt name ctxt.local.typedefs);
let e_uid = EnumName.fresh ctxt.path edecl.enum_decl_name in let e_uid = EnumName.fresh ctxt.local.path edecl.enum_decl_name in
{ let typedefs =
ctxt with Ident.Map.add
typedefs = (Mark.remove edecl.enum_decl_name)
Ident.Map.add (TEnum e_uid) ctxt.local.typedefs
(Mark.remove edecl.enum_decl_name) in
(TEnum e_uid) ctxt.typedefs; { ctxt with local = { ctxt.local with typedefs} }
}
| ScopeUse _ -> ctxt | ScopeUse _ -> ctxt
| Topdef def -> | Topdef def ->
let name, pos = def.topdef_name in let name, pos = def.topdef_name in
@ -708,9 +707,10 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
(fun use -> (fun use ->
raise_already_defined_error (TopdefName.get_info use) name pos raise_already_defined_error (TopdefName.get_info use) name pos
"toplevel definition") "toplevel definition")
(Ident.Map.find_opt name ctxt.topdefs); (Ident.Map.find_opt name ctxt.local.topdefs);
let uid = TopdefName.fresh ctxt.path def.topdef_name in let uid = TopdefName.fresh ctxt.local.path def.topdef_name in
{ ctxt with topdefs = Ident.Map.add name uid ctxt.topdefs } let topdefs = Ident.Map.add name uid ctxt.local.topdefs in
{ ctxt with local = { ctxt.local with topdefs } }
(** Process a code item that is a declaration *) (** Process a code item that is a declaration *)
let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
@ -918,7 +918,7 @@ let process_scope_use (ctxt : context) (suse : Surface.Ast.scope_use) : context
match match
Ident.Map.find_opt Ident.Map.find_opt
(Mark.remove suse.Surface.Ast.scope_use_name) (Mark.remove suse.Surface.Ast.scope_use_name)
ctxt.typedefs ctxt.local.typedefs
with with
| Some (TScope (sn, _)) -> sn | Some (TScope (sn, _)) -> sn
| _ -> | _ ->
@ -940,83 +940,90 @@ let process_use_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) :
(** {1 API} *) (** {1 API} *)
let empty_ctxt = let empty_module_ctxt = {
{ path = [];
path = []; typedefs = Ident.Map.empty;
typedefs = Ident.Map.empty; field_idmap = Ident.Map.empty;
scopes = ScopeName.Map.empty; constructor_idmap = Ident.Map.empty;
topdefs = Ident.Map.empty; topdefs = Ident.Map.empty;
topdef_types = TopdefName.Map.empty; used_modules = Ident.Map.empty;
var_typs = ScopeVar.Map.empty; }
structs = StructName.Map.empty;
field_idmap = Ident.Map.empty;
enums = EnumName.Map.empty;
constructor_idmap = Ident.Map.empty;
modules = ModuleName.Map.empty;
}
let import_module modules (name, intf) = let empty_ctxt = {
let mname = ModuleName.of_string name in scopes = ScopeName.Map.empty;
let ctxt = { empty_ctxt with modules; path = [mname] } in topdef_types = TopdefName.Map.empty;
let ctxt = List.fold_left process_name_item ctxt intf in var_typs = ScopeVar.Map.empty;
let ctxt = List.fold_left process_decl_item ctxt intf in structs = StructName.Map.empty;
let ctxt = { ctxt with modules = empty_ctxt.modules } in enums = EnumName.Map.empty;
(* No submodules at the moment, a module may use the ones loaded before it, modules = ModuleName.Map.empty;
but doesn't reexport them *) local = empty_module_ctxt;
ModuleName.Map.add mname ctxt modules }
(** Derive the context from metadata, in one pass over the declarations *) (** Derive the context from metadata, in one pass over the declarations *)
let form_context (prgm : Surface.Ast.program) : context = let form_context (surface, mod_uses) surface_modules : context =
let modules = let rec process_modules ctxt mod_uses =
List.fold_left import_module ModuleName.Map.empty prgm.program_modules (* Recursing on [mod_uses] rather than folding on [modules] ensures a topological traversal. *)
in Ident.Map.fold (fun _alias m ctxt ->
let ctxt = { empty_ctxt with modules } in match ModuleName.Map.find_opt m ctxt.modules with
let rec gather_var_sigs acc modules = | Some _ -> ctxt
(* Scope vars from imported modules need to be accessible directly for | None ->
definitions through submodules *) let intf, mod_uses = ModuleName.Map.find m surface_modules in
ModuleName.Map.fold let ctxt = process_modules ctxt mod_uses in
(fun _modname mctx acc -> let ctxt = { ctxt with
let acc = gather_var_sigs acc mctx.modules in local = { ctxt.local with used_modules = mod_uses;
ScopeVar.Map.union (fun _ _ -> assert false) acc mctx.var_typs) path = [m] } } in
modules acc let ctxt = List.fold_left process_name_item ctxt intf.Surface.Ast.intf_code in
in let ctxt = List.fold_left process_decl_item ctxt intf.Surface.Ast.intf_code in
let ctxt = { ctxt with
{ ctxt with var_typs = gather_var_sigs ScopeVar.Map.empty ctxt.modules } modules = ModuleName.Map.add m ctxt.local ctxt.modules;
local = empty_module_ctxt }
)
mod_uses ctxt
in in
let ctxt = process_modules empty_ctxt mod_uses in
let ctxt = { ctxt with local = { empty_module_ctxt with used_modules = mod_uses } } in
let ctxt = let ctxt =
List.fold_left List.fold_left
(process_law_structure process_name_item) (process_law_structure process_name_item)
ctxt prgm.program_items ctxt surface.Surface.Ast.program_items
in in
let ctxt = let ctxt =
List.fold_left List.fold_left
(process_law_structure process_decl_item) (process_law_structure process_decl_item)
ctxt prgm.program_items ctxt surface.Surface.Ast.program_items
in in
let ctxt = let ctxt =
List.fold_left List.fold_left
(process_law_structure process_use_item) (process_law_structure process_use_item)
ctxt prgm.program_items ctxt surface.Surface.Ast.program_items
in in
let rec gather_all_constrs ctxt = (* Gather struct fields and enum constrs from direct modules: this helps with
(* Gather struct fields and enum constrs from modules: this helps with disambiguation. This is only done towards the root context, because submodules are only interfaces which don't need disambiguation ; and transitive dependencies shouldn't be visible here. *)
disambiguation *) let sub_constructor_idmap, sub_field_idmap =
let modules, constructor_idmap, field_idmap = Ident.Map.fold (fun _ m (cmap, fmap) ->
ModuleName.Map.fold let lctx = ModuleName.Map.find m ctxt.modules in
(fun m ctx (mmap, constrs, fields) -> let cmap =
let ctx = gather_all_constrs ctx in Ident.Map.union
( ModuleName.Map.add m ctx mmap, (fun _ enu1 enu2 -> Some (EnumName.Map.disjoint_union enu1 enu2))
Ident.Map.union cmap lctx.constructor_idmap
(fun _ enu1 enu2 -> in
Some (EnumName.Map.union (fun _ _ -> assert false) enu1 enu2)) let fmap =
constrs ctx.constructor_idmap, Ident.Map.union
Ident.Map.union (fun _ str1 str2 -> Some (StructName.Map.disjoint_union str1 str2))
(fun _ str1 str2 -> fmap lctx.field_idmap
Some (StructName.Map.union (fun _ _ -> assert false) str1 str2)) in
fields ctx.field_idmap )) cmap, fmap)
ctxt.modules mod_uses (Ident.Map.empty, Ident.Map.empty)
(ModuleName.Map.empty, ctxt.constructor_idmap, ctxt.field_idmap)
in
{ ctxt with modules; constructor_idmap; field_idmap }
in in
gather_all_constrs ctxt { ctxt with
local =
{ ctxt.local with
(* In the root context, don't disambiguate on submodules structs/enums when there is a conflict *)
constructor_idmap =
Ident.Map.union (fun _ base _ -> Some base)
ctxt.local.constructor_idmap sub_constructor_idmap;
field_idmap =
Ident.Map.union (fun _ base _ -> Some base)
ctxt.local.field_idmap sub_field_idmap;
}
}

View File

@ -30,10 +30,6 @@ type scope_def_context = {
label_idmap : LabelName.t Ident.Map.t; label_idmap : LabelName.t Ident.Map.t;
} }
type scope_var_or_subscope =
| ScopeVar of ScopeVar.t
| SubScope of SubScopeName.t * ScopeName.t
type scope_context = { type scope_context = {
var_idmap : scope_var_or_subscope Ident.Map.t; var_idmap : scope_var_or_subscope Ident.Map.t;
(** All variables, including scope variables and subscopes *) (** All variables, including scope variables and subscopes *)
@ -67,19 +63,24 @@ type typedef =
| TEnum of EnumName.t | TEnum of EnumName.t
| TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *)
type context = { type module_context = {
path : ModuleName.t list; path : Uid.Path.t;
(** The current path being processed. Used for generating the Uids. *) (** The current path being processed. Used for generating the Uids. *)
typedefs : typedef Ident.Map.t; typedefs : typedef Ident.Map.t;
(** Gathers the names of the scopes, structs and enums *) (** Gathers the names of the scopes, structs and enums *)
field_idmap : StructField.t StructName.Map.t Ident.Map.t; field_idmap : StructField.t StructName.Map.t Ident.Map.t;
(** The names of the struct fields. Names of fields can be shared between (** The names of the struct fields. Names of fields can be shared between
different structs *) different structs. Note that fields from submodules are included here for the root module, because disambiguating there is helpful. *)
constructor_idmap : EnumConstructor.t EnumName.Map.t Ident.Map.t; constructor_idmap : EnumConstructor.t EnumName.Map.t Ident.Map.t;
(** The names of the enum constructors. Constructor names can be shared (** The names of the enum constructors. Constructor names can be shared
between different enums *) between different enums. Note that constructors from its submodules are included here for the root module, because disambiguating there is helpful. *)
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *)
used_modules : ModuleName.t Ident.Map.t; (** Module aliases and the modules they point to *)
}
(** Context for name resolution, valid within a given module *)
type context = {
scopes : scope_context ScopeName.Map.t; (** For each scope, its context *)
topdef_types : typ TopdefName.Map.t; topdef_types : typ TopdefName.Map.t;
(** Types associated with the global definitions *) (** Types associated with the global definitions *)
structs : struct_context StructName.Map.t; structs : struct_context StructName.Map.t;
@ -87,9 +88,12 @@ type context = {
enums : enum_context EnumName.Map.t; (** For each enum, its context *) enums : enum_context EnumName.Map.t; (** For each enum, its context *)
var_typs : var_sig ScopeVar.Map.t; var_typs : var_sig ScopeVar.Map.t;
(** The signatures of each scope variable declared *) (** The signatures of each scope variable declared *)
modules : context ModuleName.Map.t; modules : module_context ModuleName.Map.t;
(** The map to the interfaces of all modules (transitively) used by the program. References are made through [local.used_modules] *)
local : module_context;
(** Local context of the root module corresponding to the program being analysed *)
} }
(** Main context used throughout {!module: Desugared.From_surface} *) (** Global context used throughout {!module: Surface.Desugaring} *)
(** {1 Helpers} *) (** {1 Helpers} *)
@ -101,6 +105,12 @@ val raise_unknown_identifier : string -> Ident.t Mark.pos -> 'a
(** Function to call whenever an identifier used somewhere has not been declared (** Function to call whenever an identifier used somewhere has not been declared
in the program previously *) in the program previously *)
val get_modname : context -> Ident.t Mark.pos -> ModuleName.t
(** Emits a user error if the module name is not found *)
val get_module_ctx : context -> Ident.t Mark.pos -> context
(** Emits a user error if the module name is not found *)
val get_var_typ : context -> ScopeVar.t -> typ val get_var_typ : context -> ScopeVar.t -> typ
(** Gets the type associated to an uid *) (** Gets the type associated to an uid *)
@ -166,5 +176,8 @@ val process_type : context -> Surface.Ast.typ -> typ
(** {1 API} *) (** {1 API} *)
val form_context : Surface.Ast.program -> context val form_context :
Surface.Ast.program * ModuleName.t Ident.Map.t
-> (Surface.Ast.interface * ModuleName.t Ident.Map.t) ModuleName.Map.t
-> context
(** Derive the context from metadata, in one pass over the declarations *) (** Derive the context from metadata, in one pass over the declarations *)

View File

@ -29,66 +29,86 @@ let modname_of_file f =
let load_module_interfaces options includes program = let load_module_interfaces options includes program =
(* Recurse into program modules, looking up files in [using] and loading (* Recurse into program modules, looking up files in [using] and loading
them *) them *)
if program.Surface.Ast.program_used_modules <> [] then
Message.emit_debug "Loading module interfaces...";
let includes = let includes =
includes includes
|> List.map (fun d -> File.Tree.build (options.Cli.path_rewrite d)) |> List.map (fun d -> File.Tree.build (options.Cli.path_rewrite d))
|> List.fold_left File.Tree.union File.Tree.empty |> List.fold_left File.Tree.union File.Tree.empty
in in
let err_req_pos chain = let err_req_pos chain =
List.map (fun m -> Some "Module required from", ModuleName.pos m) chain List.map (fun mpos -> Some "Module required from", mpos) chain
in in
let find_module req_chain m = let find_module req_chain (mname, mpos) =
let fname_base = ModuleName.to_string m in let required_from_file = Pos.get_file mpos in
let required_from_file = Pos.get_file (ModuleName.pos m) in
let includes = let includes =
File.Tree.union includes File.Tree.union includes
(File.Tree.build (File.dirname required_from_file)) (File.Tree.build (File.dirname required_from_file))
in in
match match
List.filter_map List.filter_map
(fun (ext, _) -> File.Tree.lookup includes (fname_base ^ ext)) (fun (ext, _) -> File.Tree.lookup includes (mname ^ ext))
extensions extensions
with with
| [] -> | [] ->
Message.raise_multispanned_error Message.raise_multispanned_error
(err_req_pos (m :: req_chain)) (err_req_pos (mpos :: req_chain))
"Required module not found: %a" ModuleName.format m "Required module not found: @{<blue>%s@}" mname
| [f] -> f | [f] -> f
| ms -> | ms ->
Message.raise_multispanned_error Message.raise_multispanned_error
(err_req_pos (m :: req_chain)) (err_req_pos (mpos :: req_chain))
"Required module %a matches multiple files: %a" ModuleName.format m "Required module @{<blue>%s@} matches multiple files:@;<1 2>%a" mname
(Format.pp_print_list ~pp_sep:Format.pp_print_space File.format) (Format.pp_print_list ~pp_sep:Format.pp_print_space File.format)
ms ms
in in
let load_file f = (* modulename * program * (id -> modulename) *)
let (mname, intf), using = let rec aux req_chain seen uses =
Surface.Parser_driver.load_interface (Cli.FileName f) List.fold_left (fun (seen, use_map) use ->
in let f = find_module req_chain use.Surface.Ast.mod_use_name in
(ModuleName.of_string mname, intf), using match File.Map.find_opt f seen with
| Some (Some (modname, _, _)) ->
seen,
Ident.Map.add
(Mark.remove use.Surface.Ast.mod_use_alias) modname use_map
| Some None ->
Message.raise_multispanned_error
(err_req_pos (Mark.get use.Surface.Ast.mod_use_name :: req_chain))
"Circular module dependency"
| None ->
let intf = Surface.Parser_driver.load_interface (Cli.FileName f) in
let modname = ModuleName.fresh use.Surface.Ast.mod_use_name in
let seen = File.Map.add f None seen in
let seen, sub_use_map =
aux
(Mark.get use.Surface.Ast.mod_use_name :: req_chain)
seen
intf.Surface.Ast.intf_submodules
in
File.Map.add f (Some (modname, intf, sub_use_map)) seen,
Ident.Map.add
(Mark.remove use.Surface.Ast.mod_use_alias) modname use_map)
(seen, Ident.Map.empty) uses
in in
let rec aux req_chain acc modules = let seen =
List.fold_left match program.Surface.Ast.program_module_name with
(fun acc mname -> | Some m ->
let m = ModuleName.of_string mname in let file = Pos.get_file (Mark.get m) in
if List.exists (fun (m1, _) -> ModuleName.equal m m1) acc then acc File.Map.singleton file None
else | None -> File.Map.empty
let f = find_module req_chain m in
let (m', intf), using = load_file f in
if not (ModuleName.equal m m') then
Message.raise_multispanned_error
((Some "Module name declaration", ModuleName.pos m')
:: err_req_pos (m :: req_chain))
"Mismatching module name declaration:";
let acc = (m', intf) :: acc in
aux (m :: req_chain) acc using)
acc modules
in in
let program_modules = let file_module_map, root_uses =
aux [] [] (List.map fst program.Surface.Ast.program_modules) aux [] seen program.Surface.Ast.program_used_modules
|> List.map (fun (m, i) -> (m : ModuleName.t :> string Mark.pos), i)
in in
{ program with Surface.Ast.program_modules } let modules =
File.Map.fold
(fun _ info acc -> match info with
| None -> acc
| Some (mname, intf, use_map) ->
ModuleName.Map.add mname (intf, use_map) acc)
file_module_map ModuleName.Map.empty
in
root_uses, modules
module Passes = struct module Passes = struct
(* Each pass takes only its cli options, then calls upon its dependent passes (* Each pass takes only its cli options, then calls upon its dependent passes
@ -98,23 +118,20 @@ module Passes = struct
Message.emit_debug "@{<bold;magenta>=@} @{<bold>%s@} @{<bold;magenta>=@}" Message.emit_debug "@{<bold;magenta>=@} @{<bold>%s@} @{<bold;magenta>=@}"
(String.uppercase_ascii s) (String.uppercase_ascii s)
let surface options ~includes : Surface.Ast.program = let surface options : Surface.Ast.program =
debug_pass_name "surface"; debug_pass_name "surface";
let prg = let prg =
Surface.Parser_driver.parse_top_level_file options.Cli.input_src Surface.Parser_driver.parse_top_level_file options.Cli.input_src
in in
let prg = Surface.Fill_positions.fill_pos_with_legislative_info prg in Surface.Fill_positions.fill_pos_with_legislative_info prg
load_module_interfaces options includes prg
let desugared options ~includes : let desugared options ~includes :
Desugared.Ast.program * Desugared.Name_resolution.context = Desugared.Ast.program * Desugared.Name_resolution.context =
let prg = surface options ~includes in let prg = surface options in
let mod_uses, modules = load_module_interfaces options includes prg in
debug_pass_name "desugared"; debug_pass_name "desugared";
Message.emit_debug "Name resolution..."; Message.emit_debug "Name resolution...";
let ctx = Desugared.Name_resolution.form_context prg in let ctx = Desugared.Name_resolution.form_context (prg, mod_uses) modules in
(* let scope_uid = get_scope_uid options backend ctx in
* (\* This uid is a Desugared identifier *\)
* let variable_uid = get_variable_uid options backend ctx scope_uid in *)
Message.emit_debug "Desugaring..."; Message.emit_debug "Desugaring...";
let prg = Desugared.From_surface.translate_program ctx prg in let prg = Desugared.From_surface.translate_program ctx prg in
Message.emit_debug "Disambiguating..."; Message.emit_debug "Disambiguating...";
@ -122,16 +139,10 @@ module Passes = struct
Message.emit_debug "Linting..."; Message.emit_debug "Linting...";
Desugared.Linting.lint_program prg; Desugared.Linting.lint_program prg;
prg, ctx prg, ctx
(* Note: we forward the name resolution context throughout in order to locate
uids from strings. Maybe a reduced form should be included directly in
[prg] for that purpose *)
let scopelang options ~includes : let scopelang options ~includes :
untyped Scopelang.Ast.program untyped Scopelang.Ast.program =
* Desugared.Name_resolution.context let prg, _ = desugared options ~includes in
* Desugared.Dependency.ExceptionsDependencies.t
Desugared.Ast.ScopeDef.Map.t =
let prg, ctx = desugared options ~includes in
debug_pass_name "scopelang"; debug_pass_name "scopelang";
let exceptions_graphs = let exceptions_graphs =
Scopelang.From_desugared.build_exceptions_graph prg Scopelang.From_desugared.build_exceptions_graph prg
@ -139,7 +150,7 @@ module Passes = struct
let prg = let prg =
Scopelang.From_desugared.translate_program prg exceptions_graphs Scopelang.From_desugared.translate_program prg exceptions_graphs
in in
prg, ctx, exceptions_graphs prg
let dcalc : let dcalc :
type ty. type ty.
@ -149,10 +160,9 @@ module Passes = struct
check_invariants:bool -> check_invariants:bool ->
typed:ty mark -> typed:ty mark ->
ty Dcalc.Ast.program ty Dcalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list = * Scopelang.Dependency.TVertex.t list =
fun options ~includes ~optimize ~check_invariants ~typed -> fun options ~includes ~optimize ~check_invariants ~typed ->
let prg, ctx, _ = scopelang options ~includes in let prg = scopelang options ~includes in
debug_pass_name "dcalc"; debug_pass_name "dcalc";
let type_ordering = let type_ordering =
Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs
@ -199,7 +209,7 @@ module Passes = struct
(Message.raise_internal_error "Some Dcalc invariants are invalid") (Message.raise_internal_error "Some Dcalc invariants are invalid")
| _ -> | _ ->
Message.raise_error "--check_invariants cannot be used with --no-typing"); Message.raise_error "--check_invariants cannot be used with --no-typing");
prg, ctx, type_ordering prg, type_ordering
let lcalc let lcalc
(type ty) (type ty)
@ -211,9 +221,8 @@ module Passes = struct
~avoid_exceptions ~avoid_exceptions
~closure_conversion : ~closure_conversion :
untyped Lcalc.Ast.program untyped Lcalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list = * Scopelang.Dependency.TVertex.t list =
let prg, ctx, type_ordering = let prg, type_ordering =
dcalc options ~includes ~optimize ~check_invariants ~typed dcalc options ~includes ~optimize ~check_invariants ~typed
in in
debug_pass_name "lcalc"; debug_pass_name "lcalc";
@ -265,7 +274,7 @@ module Passes = struct
prg prg
| Custom _ -> assert false) | Custom _ -> assert false)
in in
prg, ctx, type_ordering prg, type_ordering
let scalc let scalc
options options
@ -275,42 +284,34 @@ module Passes = struct
~avoid_exceptions ~avoid_exceptions
~closure_conversion : ~closure_conversion :
Scalc.Ast.program Scalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list = * Scopelang.Dependency.TVertex.t list =
let prg, ctx, type_ordering = let prg, type_ordering =
lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed lcalc options ~includes ~optimize ~check_invariants ~typed:Expr.typed
~avoid_exceptions ~closure_conversion ~avoid_exceptions ~closure_conversion
in in
debug_pass_name "scalc"; debug_pass_name "scalc";
Scalc.From_lcalc.translate_program prg, ctx, type_ordering Scalc.From_lcalc.translate_program prg, type_ordering
end end
module Commands = struct module Commands = struct
open Cmdliner open Cmdliner
let get_scope_uid (ctxt : Desugared.Name_resolution.context) (scope : string) let get_scope_uid (ctx: decl_ctx) (scope : string): ScopeName.t
= =
match Ident.Map.find_opt scope ctxt.typedefs with if String.contains scope '.' then
| Some (Desugared.Name_resolution.TScope (uid, _)) -> uid Message.raise_error "Only references to the top-level module are allowed";
| _ -> try Ident.Map.find scope ctx.ctx_scope_index with
| Ident.Map.Not_found _ ->
Message.raise_error Message.raise_error
"There is no scope @{<yellow>\"%s\"@} inside the program." scope "There is no scope @{<yellow>\"%s\"@} inside the program." scope
(* TODO: this is very weird but I'm trying to maintain the current behaviour (* TODO: this is very weird but I'm trying to maintain the current behaviour
for now *) for now *)
let get_random_scope_uid (ctxt : Desugared.Name_resolution.context) = let get_random_scope_uid (ctx: decl_ctx): ScopeName.t =
let _, scope = match Ident.Map.choose_opt ctx.ctx_scope_index with
try | Some (_, name) -> name
Shared_ast.Ident.Map.filter_map | None ->
(fun _ -> function Message.raise_error "There isn't any scope inside the program."
| Desugared.Name_resolution.TScope (uid, _) -> Some uid
| _ -> None)
ctxt.typedefs
|> Shared_ast.Ident.Map.choose
with Not_found ->
Message.raise_error "There isn't any scope inside the program."
in
scope
let get_variable_uid let get_variable_uid
(ctxt : Desugared.Name_resolution.context) (ctxt : Desugared.Name_resolution.context)
@ -333,7 +334,7 @@ module Commands = struct
"Variable @{<yellow>\"%s\"@} not found inside scope @{<yellow>\"%a\"@}" "Variable @{<yellow>\"%s\"@} not found inside scope @{<yellow>\"%a\"@}"
variable ScopeName.format scope_uid variable ScopeName.format scope_uid
| Some | Some
(Desugared.Name_resolution.SubScope (subscope_var_name, subscope_name)) (SubScope (subscope_var_name, subscope_name))
-> ( -> (
match second_part with match second_part with
| None -> | None ->
@ -353,7 +354,7 @@ module Commands = struct
Ident.Map.find_opt second_part Ident.Map.find_opt second_part
(ScopeName.Map.find subscope_name ctxt.scopes).var_idmap (ScopeName.Map.find subscope_name ctxt.scopes).var_idmap
with with
| Some (Desugared.Name_resolution.ScopeVar v) -> | Some (ScopeVar v) ->
Desugared.Ast.ScopeDef.SubScopeVar (subscope_var_name, v, Pos.no_pos) Desugared.Ast.ScopeDef.SubScopeVar (subscope_var_name, v, Pos.no_pos)
| _ -> | _ ->
Message.raise_error Message.raise_error
@ -362,7 +363,7 @@ module Commands = struct
arguments." arguments."
second_part SubScopeName.format subscope_var_name ScopeName.format second_part SubScopeName.format subscope_var_name ScopeName.format
scope_uid)) scope_uid))
| Some (Desugared.Name_resolution.ScopeVar v) -> | Some (ScopeVar v) ->
Desugared.Ast.ScopeDef.Var Desugared.Ast.ScopeDef.Var
( v, ( v,
Option.map Option.map
@ -389,7 +390,7 @@ module Commands = struct
~output_file ?ext () ~output_file ?ext ()
let makefile options output = let makefile options output =
let prg = Passes.surface options ~includes:[] in let prg = Passes.surface options in
let backend_extensions_list = [".tex"] in let backend_extensions_list = [".tex"] in
let source_file = Cli.input_src_file options.Cli.input_src in let source_file = Cli.input_src_file options.Cli.input_src in
let output_file, with_output = get_output options ~ext:".d" output in let output_file, with_output = get_output options ~ext:".d" output in
@ -415,7 +416,7 @@ module Commands = struct
Term.(const makefile $ Cli.Flags.Global.options $ Cli.Flags.output) Term.(const makefile $ Cli.Flags.Global.options $ Cli.Flags.output)
let html options output print_only_law wrap_weaved_output = let html options output print_only_law wrap_weaved_output =
let prg = Passes.surface options ~includes:[] in let prg = Passes.surface options in
Message.emit_debug "Weaving literate program into HTML"; Message.emit_debug "Weaving literate program into HTML";
let output_file, with_output = let output_file, with_output =
get_output_format options ~ext:".html" output get_output_format options ~ext:".html" output
@ -444,7 +445,7 @@ module Commands = struct
$ Cli.Flags.wrap_weaved_output) $ Cli.Flags.wrap_weaved_output)
let latex options output print_only_law wrap_weaved_output = let latex options output print_only_law wrap_weaved_output =
let prg = Passes.surface options ~includes:[] in let prg = Passes.surface options in
Message.emit_debug "Weaving literate program into LaTeX"; Message.emit_debug "Weaving literate program into LaTeX";
let output_file, with_output = let output_file, with_output =
get_output_format options ~ext:".tex" output get_output_format options ~ext:".tex" output
@ -473,8 +474,12 @@ module Commands = struct
$ Cli.Flags.wrap_weaved_output) $ Cli.Flags.wrap_weaved_output)
let exceptions options includes ex_scope ex_variable = let exceptions options includes ex_scope ex_variable =
let _, ctxt, exceptions_graphs = Passes.scopelang options ~includes in let prg, ctxt = Passes.desugared options ~includes in
let scope_uid = get_scope_uid ctxt ex_scope in Passes.debug_pass_name "scopelang";
let exceptions_graphs =
Scopelang.From_desugared.build_exceptions_graph prg
in
let scope_uid = get_scope_uid prg.program_ctx ex_scope in
let variable_uid = get_variable_uid ctxt scope_uid ex_variable in let variable_uid = get_variable_uid ctxt scope_uid ex_variable in
Desugared.Print.print_exceptions_graph scope_uid variable_uid Desugared.Print.print_exceptions_graph scope_uid variable_uid
(Desugared.Ast.ScopeDef.Map.find variable_uid exceptions_graphs) (Desugared.Ast.ScopeDef.Map.find variable_uid exceptions_graphs)
@ -496,13 +501,13 @@ module Commands = struct
$ Cli.Flags.ex_variable) $ Cli.Flags.ex_variable)
let scopelang options includes output ex_scope_opt = let scopelang options includes output ex_scope_opt =
let prg, ctx, _ = Passes.scopelang options ~includes in let prg = Passes.scopelang options ~includes in
let _output_file, with_output = get_output_format options output in let _output_file, with_output = get_output_format options output in
with_output with_output
@@ fun fmt -> @@ fun fmt ->
match ex_scope_opt with match ex_scope_opt with
| Some scope -> | Some scope ->
let scope_uid = get_scope_uid ctx scope in let scope_uid = get_scope_uid prg.program_ctx scope in
Scopelang.Print.scope ~debug:options.Cli.debug prg.program_ctx fmt Scopelang.Print.scope ~debug:options.Cli.debug prg.program_ctx fmt
(scope_uid, ScopeName.Map.find scope_uid prg.program_scopes); (scope_uid, ScopeName.Map.find scope_uid prg.program_scopes);
Format.pp_print_newline fmt () Format.pp_print_newline fmt ()
@ -525,7 +530,7 @@ module Commands = struct
$ Cli.Flags.ex_scope_opt) $ Cli.Flags.ex_scope_opt)
let typecheck options includes = let typecheck options includes =
let prg, _, _ = Passes.scopelang options ~includes in let prg = Passes.scopelang options ~includes in
Message.emit_debug "Typechecking..."; Message.emit_debug "Typechecking...";
let _type_ordering = let _type_ordering =
Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs
@ -547,7 +552,7 @@ module Commands = struct
let dcalc typed options includes output optimize ex_scope_opt check_invariants let dcalc typed options includes output optimize ex_scope_opt check_invariants
= =
let prg, ctx, _ = let prg, _ =
Passes.dcalc options ~includes ~optimize ~check_invariants ~typed Passes.dcalc options ~includes ~optimize ~check_invariants ~typed
in in
let _output_file, with_output = get_output_format options output in let _output_file, with_output = get_output_format options output in
@ -555,7 +560,7 @@ module Commands = struct
@@ fun fmt -> @@ fun fmt ->
match ex_scope_opt with match ex_scope_opt with
| Some scope -> | Some scope ->
let scope_uid = get_scope_uid ctx scope in let scope_uid = get_scope_uid prg.decl_ctx scope in
Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt
( scope_uid, ( scope_uid,
Option.get Option.get
@ -568,7 +573,7 @@ module Commands = struct
prg.code_items) ); prg.code_items) );
Format.pp_print_newline fmt () Format.pp_print_newline fmt ()
| None -> | None ->
let scope_uid = get_random_scope_uid ctx in let scope_uid = get_random_scope_uid prg.decl_ctx in
(* TODO: ??? *) (* TODO: ??? *)
let prg_dcalc_expr = Expr.unbox (Program.to_expr prg scope_uid) in let prg_dcalc_expr = Expr.unbox (Program.to_expr prg scope_uid) in
Format.fprintf fmt "%a\n" Format.fprintf fmt "%a\n"
@ -602,14 +607,14 @@ module Commands = struct
ex_scope_opt ex_scope_opt
check_invariants check_invariants
disable_counterexamples = disable_counterexamples =
let prg, ctx, _ = let prg, _ =
Passes.dcalc options ~includes ~optimize ~check_invariants Passes.dcalc options ~includes ~optimize ~check_invariants
~typed:Expr.typed ~typed:Expr.typed
in in
Verification.Globals.setup ~optimize ~disable_counterexamples; Verification.Globals.setup ~optimize ~disable_counterexamples;
let vcs = let vcs =
Verification.Conditions.generate_verification_conditions prg Verification.Conditions.generate_verification_conditions prg
(Option.map (get_scope_uid ctx) ex_scope_opt) (Option.map (get_scope_uid prg.decl_ctx) ex_scope_opt)
in in
Verification.Solver.solve_vc prg.decl_ctx vcs Verification.Solver.solve_vc prg.decl_ctx vcs
@ -654,12 +659,12 @@ module Commands = struct
let interpret_dcalc typed options includes optimize check_invariants ex_scope let interpret_dcalc typed options includes optimize check_invariants ex_scope
= =
let prg, ctx, _ = let prg, _ =
Passes.dcalc options ~includes ~optimize ~check_invariants ~typed Passes.dcalc options ~includes ~optimize ~check_invariants ~typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules prg;
print_interpretation_results options Interpreter.interpret_program_dcalc prg print_interpretation_results options Interpreter.interpret_program_dcalc prg
(get_scope_uid ctx ex_scope) (get_scope_uid prg.decl_ctx ex_scope)
let interpret_cmd = let interpret_cmd =
let f no_typing = let f no_typing =
@ -691,7 +696,7 @@ module Commands = struct
avoid_exceptions avoid_exceptions
closure_conversion closure_conversion
ex_scope_opt = ex_scope_opt =
let prg, ctx, _ = let prg, _ =
Passes.lcalc options ~includes ~optimize ~check_invariants Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed ~avoid_exceptions ~closure_conversion ~typed
in in
@ -700,7 +705,7 @@ module Commands = struct
@@ fun fmt -> @@ fun fmt ->
match ex_scope_opt with match ex_scope_opt with
| Some scope -> | Some scope ->
let scope_uid = get_scope_uid ctx scope in let scope_uid = get_scope_uid prg.decl_ctx scope in
Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt Print.scope ~debug:options.Cli.debug prg.decl_ctx fmt
(scope_uid, Program.get_scope_body prg scope_uid); (scope_uid, Program.get_scope_body prg scope_uid);
Format.pp_print_newline fmt () Format.pp_print_newline fmt ()
@ -739,13 +744,13 @@ module Commands = struct
avoid_exceptions avoid_exceptions
closure_conversion closure_conversion
ex_scope = ex_scope =
let prg, ctx, _ = let prg, _ =
Passes.lcalc options ~includes ~optimize ~check_invariants Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed ~avoid_exceptions ~closure_conversion ~typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules prg;
print_interpretation_results options Interpreter.interpret_program_lcalc prg print_interpretation_results options Interpreter.interpret_program_lcalc prg
(get_scope_uid ctx ex_scope) (get_scope_uid prg.decl_ctx ex_scope)
let interpret_lcalc_cmd = let interpret_lcalc_cmd =
let f no_typing = let f no_typing =
@ -777,7 +782,7 @@ module Commands = struct
check_invariants check_invariants
avoid_exceptions avoid_exceptions
closure_conversion = closure_conversion =
let prg, _, type_ordering = let prg, type_ordering =
Passes.lcalc options ~includes ~optimize ~check_invariants Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed:Expr.typed ~avoid_exceptions ~closure_conversion ~typed:Expr.typed
in in
@ -814,7 +819,7 @@ module Commands = struct
avoid_exceptions avoid_exceptions
closure_conversion closure_conversion
ex_scope_opt = ex_scope_opt =
let prg, ctx, _ = let prg, _ =
Passes.scalc options ~includes ~optimize ~check_invariants Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~avoid_exceptions ~closure_conversion
in in
@ -823,7 +828,7 @@ module Commands = struct
@@ fun fmt -> @@ fun fmt ->
match ex_scope_opt with match ex_scope_opt with
| Some scope -> | Some scope ->
let scope_uid = get_scope_uid ctx scope in let scope_uid = get_scope_uid prg.decl_ctx scope in
Scalc.Print.format_item ~debug:options.Cli.debug prg.decl_ctx fmt Scalc.Print.format_item ~debug:options.Cli.debug prg.decl_ctx fmt
(List.find (List.find
(function (function
@ -860,7 +865,7 @@ module Commands = struct
check_invariants check_invariants
avoid_exceptions avoid_exceptions
closure_conversion = closure_conversion =
let prg, _, type_ordering = let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~avoid_exceptions ~closure_conversion
in in
@ -889,7 +894,7 @@ module Commands = struct
$ Cli.Flags.closure_conversion) $ Cli.Flags.closure_conversion)
let r options includes output optimize check_invariants closure_conversion = let r options includes output optimize check_invariants closure_conversion =
let prg, _, type_ordering = let prg, type_ordering =
Passes.scalc options ~includes ~optimize ~check_invariants Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions:false ~closure_conversion ~avoid_exceptions:false ~closure_conversion
in in

View File

@ -25,7 +25,8 @@ val main : unit -> unit
Each pass takes only its cli options, then calls upon its dependent passes Each pass takes only its cli options, then calls upon its dependent passes
(forwarding their options as needed) *) (forwarding their options as needed) *)
module Passes : sig module Passes : sig
val surface : Cli.options -> includes:Cli.raw_file list -> Surface.Ast.program
val surface : Cli.options -> Surface.Ast.program
val desugared : val desugared :
Cli.options -> Cli.options ->
@ -36,8 +37,6 @@ module Passes : sig
Cli.options -> Cli.options ->
includes:Cli.raw_file list -> includes:Cli.raw_file list ->
Shared_ast.untyped Scopelang.Ast.program Shared_ast.untyped Scopelang.Ast.program
* Desugared.Name_resolution.context
* Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t
val dcalc : val dcalc :
Cli.options -> Cli.options ->
@ -46,7 +45,6 @@ module Passes : sig
check_invariants:bool -> check_invariants:bool ->
typed:'m Shared_ast.mark -> typed:'m Shared_ast.mark ->
'm Dcalc.Ast.program 'm Dcalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list * Scopelang.Dependency.TVertex.t list
val lcalc : val lcalc :
@ -58,7 +56,6 @@ module Passes : sig
avoid_exceptions:bool -> avoid_exceptions:bool ->
closure_conversion:bool -> closure_conversion:bool ->
Shared_ast.untyped Lcalc.Ast.program Shared_ast.untyped Lcalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list * Scopelang.Dependency.TVertex.t list
val scalc : val scalc :
@ -69,7 +66,6 @@ module Passes : sig
avoid_exceptions:bool -> avoid_exceptions:bool ->
closure_conversion:bool -> closure_conversion:bool ->
Scalc.Ast.program Scalc.Ast.program
* Desugared.Name_resolution.context
* Scopelang.Dependency.TVertex.t list * Scopelang.Dependency.TVertex.t list
end end
@ -90,7 +86,7 @@ module Commands : sig
string option * ((Format.formatter -> 'a) -> 'a) string option * ((Format.formatter -> 'a) -> 'a)
val get_scope_uid : val get_scope_uid :
Desugared.Name_resolution.context -> string -> Shared_ast.ScopeName.t Shared_ast.decl_ctx -> string -> Shared_ast.ScopeName.t
val get_variable_uid : val get_variable_uid :
Desugared.Name_resolution.context -> Desugared.Name_resolution.context ->

View File

@ -405,26 +405,20 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box =
let replace_fun_typs t = let replace_fun_typs t =
if type_contains_arrow t then Mark.copy t TAny else t if type_contains_arrow t then Mark.copy t TAny else t
in in
let rec convert_ctx ctx = {
{ p.decl_ctx with
ctx_struct_fields = ctx.ctx_struct_fields; ctx_structs =
ctx_modules = ModuleName.Map.map convert_ctx ctx.ctx_modules; StructName.Map.map
ctx_structs = (StructField.Map.map replace_fun_typs)
StructName.Map.map p.decl_ctx.ctx_structs;
(StructField.Map.map replace_fun_typs) ctx_enums =
ctx.ctx_structs; EnumName.Map.map
ctx_enums = (EnumConstructor.Map.map replace_fun_typs)
EnumName.Map.map p.decl_ctx.ctx_enums;
(EnumConstructor.Map.map replace_fun_typs) (* Toplevel definitions may not contain scope calls or take functions as
ctx.ctx_enums; arguments at the moment, which ensures that their interfaces aren't
ctx_scopes = ctx.ctx_scopes; changed by the conversion *)
ctx_topdefs = ctx.ctx_topdefs; }
(* Toplevel definitions may not contain scope calls or take functions as
arguments at the moment, which ensures that their interfaces aren't
changed by the conversion *)
}
in
convert_ctx p.decl_ctx
in in
Bindlib.box_apply Bindlib.box_apply
(fun new_code_items -> (fun new_code_items ->

View File

@ -439,7 +439,7 @@ let run
options = options =
if not options.Cli.trace then if not options.Cli.trace then
Message.raise_error "This plugin requires the --trace flag."; Message.raise_error "This plugin requires the --trace flag.";
let prg, _, type_ordering = let prg, type_ordering =
Driver.Passes.lcalc options ~includes ~optimize ~check_invariants Driver.Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed:Expr.typed ~avoid_exceptions ~closure_conversion ~typed:Expr.typed
in in

View File

@ -1387,12 +1387,12 @@ let options =
$ base_src_url) $ base_src_url)
let run includes optimize ex_scope explain_options global_options = let run includes optimize ex_scope explain_options global_options =
let prg, ctx, _ = let prg, _ =
Driver.Passes.dcalc global_options ~includes ~optimize Driver.Passes.dcalc global_options ~includes ~optimize
~check_invariants:false ~typed:Expr.typed ~check_invariants:false ~typed:Expr.typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules prg;
let scope = Driver.Commands.get_scope_uid ctx ex_scope in let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in
(* let result_expr, env = interpret_program prg scope in *) (* let result_expr, env = interpret_program prg scope in *)
let g, base_vars, env = program_to_graph explain_options prg scope in let g, base_vars, env = program_to_graph explain_options prg scope in
log "Base variables detected: @[<hov>%a@]" log "Base variables detected: @[<hov>%a@]"

View File

@ -214,7 +214,7 @@ let run
closure_conversion closure_conversion
ex_scope ex_scope
options = options =
let prg, ctx, _ = let prg, _ =
Driver.Passes.lcalc options ~includes ~optimize ~check_invariants Driver.Passes.lcalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~typed:Expr.typed ~avoid_exceptions ~closure_conversion ~typed:Expr.typed
in in
@ -223,7 +223,7 @@ let run
in in
with_output with_output
@@ fun fmt -> @@ fun fmt ->
let scope_uid = Driver.Commands.get_scope_uid ctx ex_scope in let scope_uid = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in
Message.emit_debug Message.emit_debug
"Writing JSON schema corresponding to the scope '%a' to the file %s..." "Writing JSON schema corresponding to the scope '%a' to the file %s..."
ScopeName.format scope_uid ScopeName.format scope_uid

View File

@ -259,12 +259,12 @@ let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) :
(* -- Plugin registration -- *) (* -- Plugin registration -- *)
let run includes optimize check_invariants ex_scope options = let run includes optimize check_invariants ex_scope options =
let prg, ctx, _ = let prg, _ =
Driver.Passes.dcalc options ~includes ~optimize ~check_invariants Driver.Passes.dcalc options ~includes ~optimize ~check_invariants
~typed:Expr.typed ~typed:Expr.typed
in in
Interpreter.load_runtime_modules prg; Interpreter.load_runtime_modules prg;
let scope = Driver.Commands.get_scope_uid ctx ex_scope in let scope = Driver.Commands.get_scope_uid prg.decl_ctx ex_scope in
let result_expr, _env = interpret_program prg scope in let result_expr, _env = interpret_program prg scope in
let fmt = Format.std_formatter in let fmt = Format.std_formatter in
Expr.format fmt result_expr Expr.format fmt result_expr

View File

@ -31,7 +31,7 @@ let run
closure_conversion closure_conversion
options = options =
let open Driver.Commands in let open Driver.Commands in
let prg, _, type_ordering = let prg, type_ordering =
Driver.Passes.scalc options ~includes ~optimize ~check_invariants Driver.Passes.scalc options ~includes ~optimize ~check_invariants
~avoid_exceptions ~closure_conversion ~avoid_exceptions ~closure_conversion
in in

View File

@ -58,10 +58,10 @@ type 'm scope_decl = {
type 'm program = { type 'm program = {
program_module_name : ModuleName.t option; program_module_name : ModuleName.t option;
program_ctx : decl_ctx;
program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t;
program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t;
program_topdefs : ('m expr * typ) TopdefName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t;
program_modules : nil program ModuleName.Map.t;
program_ctx : decl_ctx;
program_lang : Cli.backend_lang; program_lang : Cli.backend_lang;
} }
@ -77,42 +77,34 @@ let type_rule decl_ctx env = function
let pos = Expr.mark_pos m in let pos = Expr.mark_pos m in
Call (sc_name, ssc_name, Typed { pos; ty = Mark.add pos TAny }) Call (sc_name, ssc_name, Typed { pos; ty = Mark.add pos TAny })
let type_program (prg : 'm program) : typed program = let type_program (type m) (prg : m program) : typed program =
(* Caution: this environment building code is very similar to that in (* Caution: this environment building code is very similar to that in
desugared/disambiguate.ml. Any edits should probably be reflected. *) desugared/disambiguate.ml. Any edits should probably be reflected. *)
let base_typing_env prg = let env = Typing.Env.empty prg.program_ctx in
let env = Typing.Env.empty prg.program_ctx in let env =
let env = TopdefName.Map.fold
TopdefName.Map.fold (fun name ty env -> Typing.Env.add_toplevel_var name ty env)
(fun name ty env -> Typing.Env.add_toplevel_var name ty env) prg.program_ctx.ctx_topdefs env
prg.program_ctx.ctx_topdefs env
in
let env =
ScopeName.Map.fold
(fun scope_name scope_decl env ->
let sg = (Mark.remove scope_decl).scope_sig in
let vars =
ScopeVar.Map.map (fun { svar_out_ty; _ } -> svar_out_ty) sg
in
let in_vars =
ScopeVar.Map.map (fun { svar_in_ty; _ } -> svar_in_ty) sg
in
Typing.Env.add_scope scope_name ~vars ~in_vars env)
prg.program_scopes env
in
env
in
let rec build_typing_env prg =
ModuleName.Map.fold
(fun modname prg ->
Typing.Env.add_module modname ~module_env:(build_typing_env prg))
prg.program_modules (base_typing_env prg)
in in
let env = let env =
ModuleName.Map.fold ScopeName.Map.fold
(fun modname prg -> (fun scope_name _info env ->
Typing.Env.add_module modname ~module_env:(build_typing_env prg)) let scope_sig =
prg.program_modules (base_typing_env prg) match ScopeName.path scope_name with
| [] -> (Mark.remove (ScopeName.Map.find scope_name prg.program_scopes)).scope_sig
| p ->
let m = List.hd (List.rev p) in
let scope = ScopeName.Map.find scope_name (ModuleName.Map.find m prg.program_modules) in
(Mark.remove scope).scope_sig
in
let vars =
ScopeVar.Map.map (fun { svar_out_ty; _ } -> svar_out_ty) scope_sig
in
let in_vars =
ScopeVar.Map.map (fun { svar_in_ty; _ } -> svar_in_ty) scope_sig
in
Typing.Env.add_scope scope_name ~vars ~in_vars env)
prg.program_ctx.ctx_scopes env
in in
let program_topdefs = let program_topdefs =
TopdefName.Map.map TopdefName.Map.map

View File

@ -51,14 +51,13 @@ type 'm scope_decl = {
type 'm program = { type 'm program = {
program_module_name : ModuleName.t option; program_module_name : ModuleName.t option;
program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_ctx : decl_ctx;
program_topdefs : ('m expr * typ) TopdefName.Map.t; program_modules : nil scope_decl Mark.pos ScopeName.Map.t ModuleName.Map.t;
program_modules : nil program ModuleName.Map.t;
(* Using [nil] here ensure that program interfaces don't contain any (* Using [nil] here ensure that program interfaces don't contain any
expressions. They won't contain any rules or topdefs, but will still have expressions. They won't contain any rules or topdefs, but will still have
the scope signatures needed to respect the call convention *) the scope signatures needed to respect the call convention *)
program_ctx : decl_ctx; program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t;
program_topdefs : ('m expr * typ) TopdefName.Map.t;
program_lang : Cli.backend_lang; program_lang : Cli.backend_lang;
} }
val type_program : 'm program -> typed program val type_program : 'm program -> typed program

View File

@ -31,7 +31,6 @@ type ctx = {
scope_var_mapping : target_scope_vars ScopeVar.Map.t; scope_var_mapping : target_scope_vars ScopeVar.Map.t;
reentrant_vars : typ ScopeVar.Map.t; reentrant_vars : typ ScopeVar.Map.t;
var_mapping : (D.expr, untyped Ast.expr Var.t) Var.Map.t; var_mapping : (D.expr, untyped Ast.expr Var.t) Var.Map.t;
modules : ctx ModuleName.Map.t;
} }
let tag_with_log_entry let tag_with_log_entry
@ -61,11 +60,6 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed =
| ELocation (SubScopeVar { scope; alias; var }) -> | ELocation (SubScopeVar { scope; alias; var }) ->
(* When referring to a subscope variable in an expression, we are referring (* When referring to a subscope variable in an expression, we are referring
to the output, hence we take the last state. *) to the output, hence we take the last state. *)
let ctx =
List.fold_left
(fun ctx m -> ModuleName.Map.find m ctx.modules)
ctx (ScopeName.path scope)
in
let var = let var =
match ScopeVar.Map.find (Mark.remove var) ctx.scope_var_mapping with match ScopeVar.Map.find (Mark.remove var) ctx.scope_var_mapping with
| WholeVar new_s_var -> Mark.copy var new_s_var | WholeVar new_s_var -> Mark.copy var new_s_var
@ -97,27 +91,8 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed =
}) })
m m
| ELocation (ToplevelVar v) -> Expr.elocation (ToplevelVar v) m | ELocation (ToplevelVar v) -> Expr.elocation (ToplevelVar v) m
| EDStructAccess { name_opt = None; _ } -> | EDStructAccess _ -> assert false
(* Note: this could only happen if disambiguation was disabled. If we want (* This shouldn't appear in desugared after disambiguation *)
to support it, we should still allow this case when the field has only
one possible matching structure *)
Message.raise_spanned_error (Expr.mark_pos m)
"Ambiguous structure field access"
| EDStructAccess { e; field; name_opt = Some name } ->
let e' = translate_expr ctx e in
let field =
let decl_ctx = Program.module_ctx ctx.decl_ctx (StructName.path name) in
try
StructName.Map.find name
(Ident.Map.find field decl_ctx.ctx_struct_fields)
with StructName.Map.Not_found _ | Ident.Map.Not_found _ ->
(* Should not happen after disambiguation *)
Message.raise_spanned_error (Expr.mark_pos m)
"Field @{<yellow>\"%s\"@} does not belong to structure \
@{<yellow>\"%a\"@}"
field StructName.format name
in
Expr.estructaccess ~e:e' ~field ~name m
| EScopeCall { scope; args } -> | EScopeCall { scope; args } ->
Expr.escopecall ~scope Expr.escopecall ~scope
~args: ~args:
@ -168,7 +143,7 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed =
| op, `Reversed -> | op, `Reversed ->
Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m) Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m)
| EOp _ -> assert false (* Only allowed within [EApp] *) | EOp _ -> assert false (* Only allowed within [EApp] *)
| ( EStruct _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | ELit _ | ( EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | ELit _
| EApp _ | EDefault _ | EPureDefault _ | EIfThenElse _ | EArray _ | EApp _ | EDefault _ | EPureDefault _ | EIfThenElse _ | EArray _
| EEmptyError | EErrorOnEmpty _ ) as e -> | EEmptyError | EErrorOnEmpty _ ) as e ->
Expr.map ~f:(translate_expr ctx) (e, m) Expr.map ~f:(translate_expr ctx) (e, m)
@ -300,8 +275,7 @@ let scope_to_exception_graphs (scope : D.scope) :
List.fold_left List.fold_left
(fun exceptions_graphs scope_def_key -> (fun exceptions_graphs scope_def_key ->
let new_exceptions_graphs = rule_to_exception_graph scope scope_def_key in let new_exceptions_graphs = rule_to_exception_graph scope scope_def_key in
D.ScopeDef.Map.union D.ScopeDef.Map.disjoint_union
(fun _ _ _ -> assert false (* there should not be key conflicts *))
new_exceptions_graphs exceptions_graphs) new_exceptions_graphs exceptions_graphs)
D.ScopeDef.Map.empty scope_ordering D.ScopeDef.Map.empty scope_ordering
@ -310,10 +284,9 @@ let build_exceptions_graph (pgrm : D.program) :
ScopeName.Map.fold ScopeName.Map.fold
(fun _ scope exceptions_graph -> (fun _ scope exceptions_graph ->
let new_exceptions_graphs = scope_to_exception_graphs scope in let new_exceptions_graphs = scope_to_exception_graphs scope in
D.ScopeDef.Map.union D.ScopeDef.Map.disjoint_union
(fun _ _ _ -> assert false (* key conflicts should not happen*))
new_exceptions_graphs exceptions_graph) new_exceptions_graphs exceptions_graph)
pgrm.program_scopes D.ScopeDef.Map.empty pgrm.program_root.module_scopes D.ScopeDef.Map.empty
(** Transforms a flat list of rules into a tree, taking into account the (** Transforms a flat list of rules into a tree, taking into account the
priorities declared between rules *) priorities declared between rules *)
@ -789,26 +762,31 @@ let translate_program
(* First we give mappings to all the locations between Desugared and This (* First we give mappings to all the locations between Desugared and This
involves creating a new Scopelang scope variable for every state of a involves creating a new Scopelang scope variable for every state of a
Desugared variable. *) Desugared variable. *)
let rec make_ctx desugared = let ctx =
let modules = ModuleName.Map.map make_ctx desugared.D.program_modules in let ctx =
(* Todo: since we rename all scope vars at this point, it would be better to {
have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *) scope_var_mapping = ScopeVar.Map.empty;
ScopeName.Map.fold var_mapping = Var.Map.empty;
(fun _scope scope_decl ctx -> reentrant_vars = ScopeVar.Map.empty;
ScopeVar.Map.fold decl_ctx = desugared.program_ctx;
(fun scope_var (states : D.var_or_states) ctx -> }
let var_name, var_pos = ScopeVar.get_info scope_var in in
let new_var = let add_scope_mappings modul ctx =
match states with ScopeName.Map.fold (fun _ scdef ctx ->
| D.WholeVar -> WholeVar (ScopeVar.fresh (var_name, var_pos)) ScopeVar.Map.fold
| States states -> (fun scope_var (states : D.var_or_states) ctx ->
let var_prefix = var_name ^ "_" in let var_name, var_pos = ScopeVar.get_info scope_var in
let state_var state = let new_var =
ScopeVar.fresh match states with
(Mark.map (( ^ ) var_prefix) (StateName.get_info state)) | D.WholeVar -> WholeVar (ScopeVar.fresh (var_name, var_pos))
in | States states ->
States (List.map (fun state -> state, state_var state) states) let var_prefix = var_name ^ "_" in
in let state_var state =
ScopeVar.fresh
(Mark.map (( ^ ) var_prefix) (StateName.get_info state))
in
States (List.map (fun state -> state, state_var state) states)
in
let reentrant = let reentrant =
let state = let state =
match states with match states with
@ -819,7 +797,7 @@ let translate_program
match match
D.ScopeDef.Map.find_opt D.ScopeDef.Map.find_opt
(Var (scope_var, state)) (Var (scope_var, state))
scope_decl.D.scope_defs scdef.D.scope_defs
with with
| Some | Some
{ {
@ -830,96 +808,53 @@ let translate_program
Some scope_def_typ Some scope_def_typ
| _ -> None | _ -> None
in in
{ {
ctx with ctx with
scope_var_mapping = scope_var_mapping =
ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping; ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping;
reentrant_vars = reentrant_vars =
Option.fold reentrant Option.fold reentrant
~some:(fun ty -> ~some:(fun ty ->
ScopeVar.Map.add scope_var ty ctx.reentrant_vars) ScopeVar.Map.add scope_var ty ctx.reentrant_vars)
~none:ctx.reentrant_vars; ~none:ctx.reentrant_vars;
}) })
scope_decl.D.scope_vars ctx) scdef.D.scope_vars ctx)
desugared.D.program_scopes modul.D.module_scopes ctx
{
scope_var_mapping = ScopeVar.Map.empty;
var_mapping = Var.Map.empty;
reentrant_vars = ScopeVar.Map.empty;
decl_ctx = desugared.program_ctx;
modules;
}
in
let ctx = make_ctx desugared in
let rec gather_scope_vars acc modules =
ModuleName.Map.fold
(fun _modname mctx (vmap, reentr) ->
let vmap, reentr = gather_scope_vars (vmap, reentr) mctx.modules in
( ScopeVar.Map.union
(fun _ _ -> assert false)
vmap mctx.scope_var_mapping,
ScopeVar.Map.union
(fun _ _ -> assert false)
reentr mctx.reentrant_vars ))
modules acc
in
let ctx =
let scope_var_mapping, reentrant_vars =
gather_scope_vars (ctx.scope_var_mapping, ctx.reentrant_vars) ctx.modules
in in
{ ctx with scope_var_mapping; reentrant_vars } (* Todo: since we rename all scope vars at this point, it would be better to
have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *)
ModuleName.Map.fold (fun _ m ctx -> add_scope_mappings m ctx)
desugared.D.program_modules
(add_scope_mappings (desugared.D.program_root) ctx)
in in
let rec process_decl_ctx ctx decl_ctx = let decl_ctx =
let ctx_scopes = let ctx_scopes =
ScopeName.Map.map ScopeName.Map.map
(fun out_str -> (fun out_str ->
let out_struct_fields = let out_struct_fields =
ScopeVar.Map.fold ScopeVar.Map.fold
(fun var fld out_map -> (fun var fld out_map ->
let var' = let var' =
match ScopeVar.Map.find var ctx.scope_var_mapping with match ScopeVar.Map.find var ctx.scope_var_mapping with
| WholeVar v -> v | WholeVar v -> v
| States l -> snd (List.hd (List.rev l)) | States l -> snd (List.hd (List.rev l))
in in
ScopeVar.Map.add var' fld out_map) ScopeVar.Map.add var' fld out_map)
out_str.out_struct_fields ScopeVar.Map.empty out_str.out_struct_fields ScopeVar.Map.empty
in in
{ out_str with out_struct_fields }) { out_str with out_struct_fields })
decl_ctx.ctx_scopes desugared.program_ctx.ctx_scopes
in in
{ { desugared.program_ctx with ctx_scopes }
decl_ctx with
ctx_modules =
ModuleName.Map.mapi
(fun modname decl_ctx ->
let ctx = ModuleName.Map.find modname ctx.modules in
process_decl_ctx ctx decl_ctx)
decl_ctx.ctx_modules;
ctx_scopes;
}
in in
let rec process_modules program_ctx desugared = let ctx = { ctx with decl_ctx }in
ModuleName.Map.mapi let program_modules =
(fun modname m_desugared -> ModuleName.Map.map (fun m ->
let ctx = ModuleName.Map.find modname ctx.modules in ScopeName.Map.map
{ (translate_scope_interface ctx)
Ast.program_module_name = Some modname; m.D.module_scopes)
Ast.program_topdefs = TopdefName.Map.empty;
program_scopes =
ScopeName.Map.map
(translate_scope_interface ctx)
m_desugared.D.program_scopes;
program_ctx = ModuleName.Map.find modname program_ctx.ctx_modules;
program_modules =
process_modules
(ModuleName.Map.find modname program_ctx.ctx_modules)
m_desugared;
Ast.program_lang = desugared.program_lang;
})
desugared.D.program_modules desugared.D.program_modules
in in
let program_ctx = process_decl_ctx ctx desugared.D.program_ctx in
let program_modules = process_modules program_ctx desugared in
let program_topdefs = let program_topdefs =
TopdefName.Map.mapi TopdefName.Map.mapi
(fun id -> function (fun id -> function
@ -927,18 +862,18 @@ let translate_program
| None, (_, pos) -> | None, (_, pos) ->
Message.raise_spanned_error pos "No definition found for %a" Message.raise_spanned_error pos "No definition found for %a"
TopdefName.format id) TopdefName.format id)
desugared.program_topdefs desugared.program_root.module_topdefs
in in
let program_scopes = let program_scopes =
ScopeName.Map.map ScopeName.Map.map
(translate_scope ctx exc_graphs) (translate_scope ctx exc_graphs)
desugared.D.program_scopes desugared.D.program_root.module_scopes
in in
{ {
Ast.program_module_name = desugared.D.program_module_name; Ast.program_module_name = Option.map ModuleName.fresh desugared.D.program_module_name;
Ast.program_topdefs; Ast.program_topdefs;
Ast.program_scopes; Ast.program_scopes;
Ast.program_ctx; Ast.program_ctx = ctx.decl_ctx;
Ast.program_modules; Ast.program_modules;
Ast.program_lang = desugared.program_lang; Ast.program_lang = desugared.program_lang;
} }

View File

@ -102,6 +102,10 @@ module SubScopeName =
end) end)
() ()
type scope_var_or_subscope =
| ScopeVar of ScopeVar.t
| SubScope of SubScopeName.t * ScopeName.t
module StateName = module StateName =
Uid.Gen Uid.Gen
(struct (struct
@ -135,7 +139,6 @@ type desugared =
; overloaded : yes ; overloaded : yes
; resolved : no ; resolved : no
; syntacticNames : yes ; syntacticNames : yes
; resolvedNames : no
; scopeVarStates : yes ; scopeVarStates : yes
; scopeVarSimpl : no ; scopeVarSimpl : no
; explicitScopes : yes ; explicitScopes : yes
@ -143,6 +146,9 @@ type desugared =
; defaultTerms : yes ; defaultTerms : yes
; exceptions : no ; exceptions : no
; custom : no > ; custom : no >
(* Technically, desugared before name resolution has [syntacticNames: yes; resolvedNames: no], and after name resolution has the opposite; but the disambiguation being done by the typer, we don't encode this invariant at the type level.
Indeed, unfortunately, we cannot express the [<resolvedNames: _; 'a> -> <resolvedNames: yes; 'a>] that would be needed for the typing function. *)
type scopelang = type scopelang =
< monomorphic : yes < monomorphic : yes
@ -150,7 +156,6 @@ type scopelang =
; overloaded : no ; overloaded : no
; resolved : yes ; resolved : yes
; syntacticNames : no ; syntacticNames : no
; resolvedNames : yes
; scopeVarStates : no ; scopeVarStates : no
; scopeVarSimpl : yes ; scopeVarSimpl : yes
; explicitScopes : yes ; explicitScopes : yes
@ -165,7 +170,6 @@ type dcalc =
; overloaded : no ; overloaded : no
; resolved : yes ; resolved : yes
; syntacticNames : no ; syntacticNames : no
; resolvedNames : yes
; scopeVarStates : no ; scopeVarStates : no
; scopeVarSimpl : no ; scopeVarSimpl : no
; explicitScopes : no ; explicitScopes : no
@ -180,7 +184,6 @@ type lcalc =
; overloaded : no ; overloaded : no
; resolved : yes ; resolved : yes
; syntacticNames : no ; syntacticNames : no
; resolvedNames : yes
; scopeVarStates : no ; scopeVarStates : no
; scopeVarSimpl : no ; scopeVarSimpl : no
; explicitScopes : no ; explicitScopes : no
@ -199,7 +202,6 @@ type dcalc_lcalc_features =
; overloaded : no ; overloaded : no
; resolved : yes ; resolved : yes
; syntacticNames : no ; syntacticNames : no
; resolvedNames : yes
; scopeVarStates : no ; scopeVarStates : no
; scopeVarSimpl : no ; scopeVarSimpl : no
; explicitScopes : no ; explicitScopes : no
@ -535,8 +537,8 @@ and ('a, 'b, 'm) base_gexpr =
e : ('a, 'm) gexpr; e : ('a, 'm) gexpr;
field : StructField.t; field : StructField.t;
} }
-> ('a, < resolvedNames : yes ; .. >, 'm) base_gexpr -> ('a, < .. >, 'm) base_gexpr
(** Resolved struct/enums, after [desugared] *) (** Resolved struct/enums, after name resolution in [desugared] *)
(* Lambda-like *) (* Lambda-like *)
| EExternal : { | EExternal : {
name : external_ref Mark.pos; name : external_ref Mark.pos;
@ -651,8 +653,8 @@ type 'e code_item =
| ScopeDef of ScopeName.t * 'e scope_body | ScopeDef of ScopeName.t * 'e scope_body
| Topdef of TopdefName.t * typ * 'e | Topdef of TopdefName.t * typ * 'e
(* A chained list, but with a binder for each element into the next: [x := let a (** A chained list, but with a binder for each element into the next: [x := let a
= e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *) = e1 in e2] is thus [Cons (e1, {a. Cons (e2, {x. Nil})})] *)
type 'e code_item_list = type 'e code_item_list =
| Nil | Nil
| Cons of 'e code_item * ('e, 'e code_item_list) binder | Cons of 'e code_item * ('e, 'e code_item_list) binder
@ -666,14 +668,20 @@ type scope_info = {
out_struct_fields : StructField.t ScopeVar.Map.t; out_struct_fields : StructField.t ScopeVar.Map.t;
} }
type module_tree = M of module_tree ModuleName.Map.t [@@caml.unboxed]
(** In practice, this is a DAG: beware of repeated names *)
type decl_ctx = { type decl_ctx = {
ctx_enums : enum_ctx; ctx_enums : enum_ctx;
ctx_structs : struct_ctx; ctx_structs : struct_ctx;
ctx_struct_fields : StructField.t StructName.Map.t Ident.Map.t;
(** needed for disambiguation (desugared -> scope) *)
ctx_scopes : scope_info ScopeName.Map.t; ctx_scopes : scope_info ScopeName.Map.t;
ctx_topdefs : typ TopdefName.Map.t; ctx_topdefs : typ TopdefName.Map.t;
ctx_modules : decl_ctx ModuleName.Map.t; ctx_struct_fields : StructField.t StructName.Map.t Ident.Map.t;
(** needed for disambiguation (desugared -> scope) *)
ctx_enum_constrs : EnumConstructor.t EnumName.Map.t Ident.Map.t;
ctx_scope_index : ScopeName.t Ident.Map.t;
(** only used to lookup scopes (in the root module) specified from the cli *)
ctx_modules : module_tree;
} }
type 'e program = { type 'e program = {

View File

@ -134,7 +134,7 @@ val estructaccess :
field:StructField.t -> field:StructField.t ->
e:('a, 'm) boxed_gexpr -> e:('a, 'm) boxed_gexpr ->
'm mark -> 'm mark ->
((< resolvedNames : yes ; .. > as 'a), 'm) boxed_gexpr ('a any, 'm) boxed_gexpr
val einj : val einj :
name:EnumName.t -> name:EnumName.t ->

View File

@ -571,7 +571,6 @@ let rec evaluate_expr :
in in
let ty = let ty =
try try
let ctx = Program.module_ctx ctx path in
match Mark.remove name with match Mark.remove name with
| External_value name -> TopdefName.Map.find name ctx.ctx_topdefs | External_value name -> TopdefName.Map.find name ctx.ctx_topdefs
| External_scope name -> | External_scope name ->
@ -986,12 +985,13 @@ let load_runtime_modules prg =
let obj_file = let obj_file =
Dynlink.adapt_filename Dynlink.adapt_filename
File.( File.(
(Pos.get_file (ModuleName.pos m) /../ ModuleName.to_string m) ^ ".cmo") (Pos.get_file (Mark.get (ModuleName.get_info m))
/../ ModuleName.to_string m) ^ ".cmo")
in in
if not (Sys.file_exists obj_file) then if not (Sys.file_exists obj_file) then
Message.raise_spanned_error Message.raise_spanned_error
~span_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") ~span_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here")
(ModuleName.pos m) (Mark.get (ModuleName.get_info m))
"Compiled OCaml object %a not found. Make sure it has been suitably \ "Compiled OCaml object %a not found. Make sure it has been suitably \
compiled." compiled."
File.format obj_file File.format obj_file
@ -1003,20 +1003,18 @@ let load_runtime_modules prg =
obj_file Format.pp_print_text obj_file Format.pp_print_text
(Dynlink.error_message dl_err) (Dynlink.error_message dl_err)
in in
let rec aux loaded decl_ctx = let modules_list_topo =
ModuleName.Map.fold let rec aux acc (M mtree) =
(fun mname sub_decl_ctx loaded -> ModuleName.Map.fold
if ModuleName.Set.mem mname loaded then loaded (fun mname sub acc ->
else if List.exists (ModuleName.equal mname) acc then acc else
let loaded = ModuleName.Set.add mname loaded in mname :: aux acc sub)
let loaded = aux loaded sub_decl_ctx in mtree acc
load mname; in
loaded) List.rev (aux [] prg.decl_ctx.ctx_modules)
decl_ctx.ctx_modules loaded
in in
if not (ModuleName.Map.is_empty prg.decl_ctx.ctx_modules) then if modules_list_topo <> [] then
Message.emit_debug "Loading shared modules... %a" Message.emit_debug "Loading shared modules... %a"
(fun ppf -> ModuleName.Map.format_keys ppf) (Format.pp_print_list ~pp_sep:Format.pp_print_space ModuleName.format)
prg.decl_ctx.ctx_modules; modules_list_topo;
let (_loaded : ModuleName.Set.t) = aux ModuleName.Set.empty prg.decl_ctx in List.iter load modules_list_topo
()

View File

@ -74,7 +74,7 @@ module type EXPR_PARAM = sig
(** pre-processing on expressions: can be used to skip log calls, etc. *) (** pre-processing on expressions: can be used to skip log calls, etc. *)
end end
module ExprGen (C : EXPR_PARAM) : sig module ExprGen (_ : EXPR_PARAM) : sig
val expr : Format.formatter -> ('a, 't) gexpr -> unit val expr : Format.formatter -> ('a, 't) gexpr -> unit
end end

View File

@ -32,15 +32,14 @@ let empty_ctx =
{ {
ctx_enums = EnumName.Map.empty; ctx_enums = EnumName.Map.empty;
ctx_structs = StructName.Map.empty; ctx_structs = StructName.Map.empty;
ctx_struct_fields = Ident.Map.empty;
ctx_scopes = ScopeName.Map.empty; ctx_scopes = ScopeName.Map.empty;
ctx_topdefs = TopdefName.Map.empty; ctx_topdefs = TopdefName.Map.empty;
ctx_modules = ModuleName.Map.empty; ctx_struct_fields = Ident.Map.empty;
ctx_enum_constrs = Ident.Map.empty;
ctx_scope_index = Ident.Map.empty;
ctx_modules = M ModuleName.Map.empty;
} }
let module_ctx ctx path =
List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.ctx_modules) ctx path
let get_scope_body { code_items; _ } scope = let get_scope_body { code_items; _ } scope =
match match
Scope.fold_left ~init:None Scope.fold_left ~init:None

View File

@ -15,17 +15,12 @@
License for the specific language governing permissions and limitations under License for the specific language governing permissions and limitations under
the License. *) the License. *)
open Catala_utils
open Definitions open Definitions
(** {2 Program declaration context helpers} *) (** {2 Program declaration context helpers} *)
val empty_ctx : decl_ctx val empty_ctx : decl_ctx
val module_ctx : decl_ctx -> Uid.Path.t -> decl_ctx
(** Follows a path to get the corresponding context for type and value
declarations. *)
(** {2 Transformations} *) (** {2 Transformations} *)
val map_exprs : val map_exprs :

View File

@ -343,7 +343,6 @@ module Env = struct
scopes : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t; scopes : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t;
scopes_input : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t; scopes_input : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t;
toplevel_vars : A.typ A.TopdefName.Map.t; toplevel_vars : A.typ A.TopdefName.Map.t;
modules : 'e t A.ModuleName.Map.t;
} }
let empty (decl_ctx : A.decl_ctx) = let empty (decl_ctx : A.decl_ctx) =
@ -363,7 +362,6 @@ module Env = struct
scopes = A.ScopeName.Map.empty; scopes = A.ScopeName.Map.empty;
scopes_input = A.ScopeName.Map.empty; scopes_input = A.ScopeName.Map.empty;
toplevel_vars = A.TopdefName.Map.empty; toplevel_vars = A.TopdefName.Map.empty;
modules = A.ModuleName.Map.empty;
} }
let get t v = Var.Map.find_opt v t.vars let get t v = Var.Map.find_opt v t.vars
@ -374,9 +372,6 @@ module Env = struct
Option.bind (A.ScopeName.Map.find_opt scope t.scopes) (fun vmap -> Option.bind (A.ScopeName.Map.find_opt scope t.scopes) (fun vmap ->
A.ScopeVar.Map.find_opt var vmap) A.ScopeVar.Map.find_opt var vmap)
let module_env path env =
List.fold_left (fun env m -> A.ModuleName.Map.find m env.modules) env path
let add v tau t = { t with vars = Var.Map.add v tau t.vars } let add v tau t = { t with vars = Var.Map.add v tau t.vars }
let add_var v typ t = add v (ast_to_typ typ) t let add_var v typ t = add v (ast_to_typ typ) t
@ -393,19 +388,15 @@ module Env = struct
let add_toplevel_var v typ t = let add_toplevel_var v typ t =
{ t with toplevel_vars = A.TopdefName.Map.add v typ t.toplevel_vars } { t with toplevel_vars = A.TopdefName.Map.add v typ t.toplevel_vars }
let add_module modname ~module_env t =
{ t with modules = A.ModuleName.Map.add modname module_env t.modules }
let open_scope scope_name t = let open_scope scope_name t =
let scope_vars = let scope_vars =
A.ScopeVar.Map.union A.ScopeVar.Map.disjoint_union
(fun _ _ -> assert false)
t.scope_vars t.scope_vars
(A.ScopeName.Map.find scope_name t.scopes) (A.ScopeName.Map.find scope_name t.scopes)
in in
{ t with scope_vars } { t with scope_vars }
let rec dump ppf env = let dump ppf env =
let pp_sep = Format.pp_print_space in let pp_sep = Format.pp_print_space in
Format.pp_open_vbox ppf 0; Format.pp_open_vbox ppf 0;
(* Format.fprintf ppf "structs: @[<hov>%a@]@," (* Format.fprintf ppf "structs: @[<hov>%a@]@,"
@ -420,9 +411,6 @@ module Env = struct
Format.fprintf ppf "topdefs: @[<hov>%a@]@," Format.fprintf ppf "topdefs: @[<hov>%a@]@,"
(A.TopdefName.Map.format_keys ~pp_sep) (A.TopdefName.Map.format_keys ~pp_sep)
env.toplevel_vars; env.toplevel_vars;
Format.fprintf ppf "@[<hv 2>modules:@ %a@]"
(A.ModuleName.Map.format dump)
env.modules;
Format.pp_close_box ppf () Format.pp_close_box ppf ()
end end
@ -480,10 +468,8 @@ and typecheck_expr_top_down :
| DesugaredScopeVar { name; _ } | ScopelangScopeVar { name } -> | DesugaredScopeVar { name; _ } | ScopelangScopeVar { name } ->
Env.get_scope_var env (Mark.remove name) Env.get_scope_var env (Mark.remove name)
| SubScopeVar { scope; var; _ } -> | SubScopeVar { scope; var; _ } ->
let env = Env.module_env (A.ScopeName.path scope) env in
Env.get_subscope_out_var env scope (Mark.remove var) Env.get_subscope_out_var env scope (Mark.remove var)
| ToplevelVar { name } -> | ToplevelVar { name } ->
let env = Env.module_env (A.TopdefName.path (Mark.remove name)) env in
Env.get_toplevel_var env (Mark.remove name) Env.get_toplevel_var env (Mark.remove name)
in in
let ty = let ty =
@ -558,42 +544,39 @@ and typecheck_expr_top_down :
"This is not a structure, cannot access field %s (%a)" field "This is not a structure, cannot access field %s (%a)" field
(format_typ ctx) (ty e_struct') (format_typ ctx) (ty e_struct')
in in
let fld_ty = let str =
let str = try A.StructName.Map.find name env.structs
try A.StructName.Map.find name env.structs with A.StructName.Map.Not_found _ ->
with A.StructName.Map.Not_found _ -> Message.raise_spanned_error pos_e "No structure %a found"
Message.raise_spanned_error pos_e "No structure %a found" A.StructName.format name
A.StructName.format name in
in let field =
let field = let candidate_structs =
let ctx = Program.module_ctx ctx (A.StructName.path name) in try A.Ident.Map.find field ctx.ctx_struct_fields
let candidate_structs = with A.Ident.Map.Not_found _ ->
try A.Ident.Map.find field ctx.ctx_struct_fields
with A.Ident.Map.Not_found _ ->
Message.raise_spanned_error
(Expr.mark_pos context_mark)
"Field @{<yellow>\"%s\"@} does not belong to structure \
@{<yellow>\"%a\"@} (no structure defines it)"
field A.StructName.format name
in
try A.StructName.Map.find name candidate_structs
with A.StructName.Map.Not_found _ ->
Message.raise_spanned_error Message.raise_spanned_error
(Expr.mark_pos context_mark) (Expr.mark_pos context_mark)
"@[<hov>Field @{<yellow>\"%s\"@}@ does not belong to@ structure \ "Field @{<yellow>\"%s\"@} does not belong to structure \
@{<yellow>\"%a\"@},@ but to %a@]" @{<yellow>\"%a\"@} (no structure defines it)"
field A.StructName.format name field A.StructName.format name
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf "@ or@ ")
(fun fmt s_name ->
Format.fprintf fmt "@{<yellow>\"%a\"@}" A.StructName.format
s_name))
(A.StructName.Map.keys candidate_structs)
in in
A.StructField.Map.find field str try A.StructName.Map.find name candidate_structs
with A.StructName.Map.Not_found _ ->
Message.raise_spanned_error
(Expr.mark_pos context_mark)
"@[<hov>Field @{<yellow>\"%s\"@}@ does not belong to@ structure \
@{<yellow>\"%a\"@},@ but to %a@]"
field A.StructName.format name
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf "@ or@ ")
(fun fmt s_name ->
Format.fprintf fmt "@{<yellow>\"%a\"@}" A.StructName.format
s_name))
(A.StructName.Map.keys candidate_structs)
in in
let fld_ty = A.StructField.Map.find field str in
let mark = mark_with_tau_and_unify fld_ty in let mark = mark_with_tau_and_unify fld_ty in
Expr.edstructaccess ~e:e_struct' ~name_opt:(Some name) ~field mark Expr.estructaccess ~name ~e:e_struct' ~field mark
| A.EStructAccess { e = e_struct; name; field } -> | A.EStructAccess { e = e_struct; name; field } ->
let fld_ty = let fld_ty =
let str = let str =
@ -692,16 +675,11 @@ and typecheck_expr_top_down :
in in
Expr.ematch ~e:e1' ~name ~cases mark Expr.ematch ~e:e1' ~name ~cases mark
| A.EScopeCall { scope; args } -> | A.EScopeCall { scope; args } ->
let path = A.ScopeName.path scope in
let scope_out_struct = let scope_out_struct =
let ctx = Program.module_ctx ctx path in
(A.ScopeName.Map.find scope ctx.ctx_scopes).out_struct_name (A.ScopeName.Map.find scope ctx.ctx_scopes).out_struct_name
in in
let mark = mark_with_tau_and_unify (unionfind (TStruct scope_out_struct)) in let mark = mark_with_tau_and_unify (unionfind (TStruct scope_out_struct)) in
let vars = let vars = A.ScopeName.Map.find scope env.scopes_input in
let env = Env.module_env path env in
A.ScopeName.Map.find scope env.scopes_input
in
let args' = let args' =
A.ScopeVar.Map.mapi A.ScopeVar.Map.mapi
(fun name -> (fun name ->
@ -730,12 +708,6 @@ and typecheck_expr_top_down :
in in
Expr.evar (Var.translate v) (mark_with_tau_and_unify tau') Expr.evar (Var.translate v) (mark_with_tau_and_unify tau')
| A.EExternal { name } -> | A.EExternal { name } ->
let path =
match Mark.remove name with
| External_value td -> A.TopdefName.path td
| External_scope s -> A.ScopeName.path s
in
let ctx = Program.module_ctx ctx path in
let ty = let ty =
let not_found pr x = let not_found pr x =
Message.raise_spanned_error pos_e Message.raise_spanned_error pos_e

View File

@ -17,7 +17,6 @@
(** Typing for the default calculus. Because of the error terms, we perform type (** Typing for the default calculus. Because of the error terms, we perform type
inference using the classical W algorithm with union-find unification. *) inference using the classical W algorithm with union-find unification. *)
open Catala_utils
open Definitions open Definitions
module Env : sig module Env : sig
@ -35,8 +34,6 @@ module Env : sig
'e t -> 'e t ->
'e t 'e t
val add_module : ModuleName.t -> module_env:'e t -> 'e t -> 'e t
val module_env : Uid.Path.t -> 'e t -> 'e t
val open_scope : ScopeName.t -> 'e t -> 'e t val open_scope : ScopeName.t -> 'e t -> 'e t
val dump : Format.formatter -> 'e t -> unit val dump : Format.formatter -> 'e t -> unit
@ -62,7 +59,10 @@ val expr :
still done, but with unification with the existing annotations at every still done, but with unification with the existing annotations at every
step. This can be used for double-checking after AST transformations and step. This can be used for double-checking after AST transformations and
filling the gaps ([TAny]) if any. Use [Expr.untype] first if this is not filling the gaps ([TAny]) if any. Use [Expr.untype] first if this is not
what you want. *) what you want.
Note that typing also transparently performs disambiguation of constructors: [EDStructAccess] nodes are translated into [EStructAccess] with the suitable structure and field idents (this only concerns [desugared] expressions).
*)
val check_expr : val check_expr :
leave_unresolved:bool -> leave_unresolved:bool ->

View File

@ -312,15 +312,24 @@ and law_structure =
| LawText of (string[@opaque]) | LawText of (string[@opaque])
| CodeBlock of code_block * source_repr * bool (* Metadata if true *) | CodeBlock of code_block * source_repr * bool (* Metadata if true *)
and interface = uident Mark.pos * code_block and interface = {
(** Invariant: an interface shall only contain [*Decl] elements, or [Topdef] intf_modname: uident Mark.pos;
elements with [topdef_expr = None] *) intf_code: code_block;
(** Invariant: an interface shall only contain [*Decl] elements, or [Topdef]
elements with [topdef_expr = None] *)
intf_submodules: module_use list;
}
and module_use = {
mod_use_name: uident Mark.pos;
mod_use_alias: uident Mark.pos;
}
and program = { and program = {
program_module_name : uident Mark.pos option; program_module_name : uident Mark.pos option;
program_items : law_structure list; program_items : law_structure list;
program_source_files : (string[@opaque]) list; program_source_files : (string[@opaque]) list;
program_modules : interface list; (** Modules being used by the program *) program_used_modules : module_use list;
program_lang : Cli.backend_lang; [@opaque] program_lang : Cli.backend_lang; [@opaque]
} }

View File

@ -248,10 +248,8 @@ let rec parse_source (lexbuf : Sedlexing.lexbuf) : Ast.program =
let commands = localised_parser language lexbuf in let commands = localised_parser language lexbuf in
let program = expand_includes source_file_name commands in let program = expand_includes source_file_name commands in
{ {
program_module_name = program.Ast.program_module_name; program with
program_items = program.Ast.program_items;
program_source_files = source_file_name :: program.Ast.program_source_files; program_source_files = source_file_name :: program.Ast.program_source_files;
program_modules = program.program_modules;
program_lang = language; program_lang = language;
} }
@ -278,10 +276,12 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
Ast.program_module_name = join_module_names (Some id); Ast.program_module_name = join_module_names (Some id);
Ast.program_items = command :: acc.Ast.program_items; Ast.program_items = command :: acc.Ast.program_items;
} }
| Ast.ModuleUse (id, _alias) -> | Ast.ModuleUse (mod_use_name, alias) ->
let mod_use_alias = Option.value ~default:mod_use_name alias in
{ {
acc with acc with
Ast.program_modules = (id, []) :: acc.Ast.program_modules; Ast.program_used_modules = { mod_use_name; mod_use_alias }
:: acc.Ast.program_used_modules;
Ast.program_items = command :: acc.Ast.program_items; Ast.program_items = command :: acc.Ast.program_items;
} }
| Ast.LawInclude (Ast.CatalaFile inc_file) -> | Ast.LawInclude (Ast.CatalaFile inc_file) ->
@ -301,8 +301,8 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
] ]
"A file that declares a module cannot be used through the raw \ "A file that declares a module cannot be used through the raw \
'@{<yellow>> Include@}' directive. You should use it as a \ '@{<yellow>> Include@}' directive. You should use it as a \
module with '@{<yellow>> Use %a@}' instead." module with '@{<yellow>> Use @{<blue>%s@}@}' instead."
Uid.Module.format (Uid.Module.of_string id) (Mark.remove id)
in in
{ {
Ast.program_module_name = acc.program_module_name; Ast.program_module_name = acc.program_module_name;
@ -311,9 +311,9 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
acc.Ast.program_source_files; acc.Ast.program_source_files;
Ast.program_items = Ast.program_items =
List.rev_append includ_program.program_items acc.Ast.program_items; List.rev_append includ_program.program_items acc.Ast.program_items;
Ast.program_modules = Ast.program_used_modules =
List.rev_append includ_program.program_modules List.rev_append includ_program.program_used_modules
acc.Ast.program_modules; acc.Ast.program_used_modules;
Ast.program_lang = language; Ast.program_lang = language;
} }
| Ast.LawHeading (heading, commands') -> | Ast.LawHeading (heading, commands') ->
@ -321,7 +321,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
Ast.program_module_name; Ast.program_module_name;
Ast.program_items = commands'; Ast.program_items = commands';
Ast.program_source_files = new_sources; Ast.program_source_files = new_sources;
Ast.program_modules = new_modules; Ast.program_used_modules = new_used_modules;
Ast.program_lang = _; Ast.program_lang = _;
} = } =
expand_includes source_file commands' expand_includes source_file commands'
@ -332,8 +332,8 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
List.rev_append new_sources acc.Ast.program_source_files; List.rev_append new_sources acc.Ast.program_source_files;
Ast.program_items = Ast.program_items =
Ast.LawHeading (heading, commands') :: acc.Ast.program_items; Ast.LawHeading (heading, commands') :: acc.Ast.program_items;
Ast.program_modules = Ast.program_used_modules =
List.rev_append new_modules acc.Ast.program_modules; List.rev_append new_used_modules acc.Ast.program_used_modules;
Ast.program_lang = language; Ast.program_lang = language;
} }
| i -> { acc with Ast.program_items = i :: acc.Ast.program_items }) | i -> { acc with Ast.program_items = i :: acc.Ast.program_items })
@ -341,7 +341,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
Ast.program_module_name = None; Ast.program_module_name = None;
Ast.program_source_files = []; Ast.program_source_files = [];
Ast.program_items = []; Ast.program_items = [];
Ast.program_modules = []; Ast.program_used_modules = [];
Ast.program_lang = language; Ast.program_lang = language;
} }
commands commands
@ -351,7 +351,7 @@ and expand_includes (source_file : string) (commands : Ast.law_structure list) :
Ast.program_module_name = rprg.Ast.program_module_name; Ast.program_module_name = rprg.Ast.program_module_name;
Ast.program_source_files = List.rev rprg.Ast.program_source_files; Ast.program_source_files = List.rev rprg.Ast.program_source_files;
Ast.program_items = List.rev rprg.Ast.program_items; Ast.program_items = List.rev rprg.Ast.program_items;
Ast.program_modules = List.rev rprg.Ast.program_modules; Ast.program_used_modules = List.rev rprg.Ast.program_used_modules;
} }
(** {2 Handling interfaces} *) (** {2 Handling interfaces} *)
@ -360,7 +360,9 @@ let get_interface program =
let rec filter (req, acc) = function let rec filter (req, acc) = function
| Ast.LawInclude _ | Ast.LawText _ | Ast.ModuleDef _ -> req, acc | Ast.LawInclude _ | Ast.LawText _ | Ast.ModuleDef _ -> req, acc
| Ast.LawHeading (_, str) -> List.fold_left filter (req, acc) str | Ast.LawHeading (_, str) -> List.fold_left filter (req, acc) str
| Ast.ModuleUse (m, _) -> m :: req, acc | Ast.ModuleUse (mod_use_name, alias) ->
{ Ast.mod_use_name; mod_use_alias = Option.value ~default:mod_use_name alias }
:: req, acc
| Ast.CodeBlock (code, _, true) -> | Ast.CodeBlock (code, _, true) ->
( req, ( req,
List.fold_left List.fold_left
@ -394,9 +396,17 @@ let with_sedlex_source source_file f =
let load_interface source_file = let load_interface source_file =
let program = with_sedlex_source source_file parse_source in let program = with_sedlex_source source_file parse_source in
let modname = let modname =
match program.Ast.program_module_name with match program.Ast.program_module_name, source_file with
| Some mname -> mname | Some (mname, pos), Cli.FileName file ->
| None -> if File.(equal mname Filename.(remove_extension (basename file)))
then mname, pos
else
Message.raise_spanned_error pos
"Module declared as @{<blue>%s@}, which does not match the file name %a"
mname
File.format file
| Some mname, _ -> mname
| None, _ ->
Message.raise_error Message.raise_error
"%a doesn't define a module name. It should contain a '@{<cyan>> \ "%a doesn't define a module name. It should contain a '@{<cyan>> \
Module %s@}' directive." Module %s@}' directive."
@ -408,7 +418,9 @@ let load_interface source_file =
| _ -> "Module_name") | _ -> "Module_name")
in in
let used_modules, intf = get_interface program in let used_modules, intf = get_interface program in
(modname, intf), used_modules { Ast.intf_modname = modname;
Ast.intf_code = intf;
Ast.intf_submodules = used_modules; }
let parse_top_level_file (source_file : Cli.input_src) : Ast.program = let parse_top_level_file (source_file : Cli.input_src) : Ast.program =
let program = with_sedlex_source source_file parse_source in let program = with_sedlex_source source_file parse_source in

View File

@ -24,9 +24,9 @@ val lines :
(** Raw file parser that doesn't interpret any includes and returns the flat law (** Raw file parser that doesn't interpret any includes and returns the flat law
structure as is *) structure as is *)
val load_interface : Cli.input_src -> Ast.interface * string Mark.pos list val load_interface : Cli.input_src -> Ast.interface
(** Reads only declarations in metadata in the supplied input file, and only (** Reads only declarations in metadata in the supplied input file, and only
keeps type information ; returns the modules used as well *) keeps type information. The list of submodules is initialised with names only and empty contents. *)
val parse_top_level_file : Cli.input_src -> Ast.program val parse_top_level_file : Cli.input_src -> Ast.program
(** Parses a catala file (handling file includes) and returns a program. (** Parses a catala file (handling file includes) and returns a program.

2
dune
View File

@ -10,7 +10,7 @@
; don't stop building because of warnings ; don't stop building because of warnings
(dev (dev
(flags (flags
(:standard -warn-error -a))) (:standard -warn-error -a -w -67)))
; for CI runs: must fail on warnings ; for CI runs: must fail on warnings
(check (check
(flags (flags

View File

@ -4,7 +4,7 @@
declaration scope T: declaration scope T:
t1 scope Mod_middle.S t1 scope Mod_middle.S
# input i content Enum1 # input i content Enum1
output o1 content Mod_def.S output o1 content Mod_middle.Mod_def.S
output o2 content money output o2 content money
output o3 content money output o3 content money