diff --git a/compiler/catala_utils/uid.ml b/compiler/catala_utils/uid.ml index 3cdbaea1..81e35070 100644 --- a/compiler/catala_utils/uid.ml +++ b/compiler/catala_utils/uid.ml @@ -58,13 +58,14 @@ module Make (X : Info) () : Id with type info = X.info = struct { id = !counter; info } let get_info (uid : t) : X.info = uid.info - let format (fmt : Format.formatter) (x : t) : unit = X.format fmt x.info let hash (x : t) : int = x.id module Set = Set.Make (Ordering) module Map = Map.Make (Ordering) end +(* - Raw idents - *) + module MarkedString = struct type info = string Mark.pos @@ -75,3 +76,54 @@ module MarkedString = struct end module Gen () = Make (MarkedString) () + +(* - Modules, paths and qualified idents - *) + +module Module = struct + include String + let to_string m = m + let format ppf m = Format.fprintf ppf "@{%s@}" m + + let of_string m = m +end +(* TODO: should probably be turned into an uid once we implement module import + directives; that will incur an additional resolution work on all paths + though ([module Module = Gen ()]) *) + +module Path = struct + type t = Module.t list + + let format ppf p = + Format.pp_print_list + ~pp_sep:(fun _ () -> ()) + (fun ppf m -> + Format.fprintf ppf "%a@{.@}" Module.format m) + ppf p + + let to_string p = String.concat "." p + let equal = List.equal String.equal + let compare = List.compare String.compare +end + +module QualifiedMarkedString = struct + type info = Path.t * MarkedString.info + + let to_string (p, i) = + Format.asprintf "%a%a" Path.format p MarkedString.format i + let format fmt (p, i) = + Path.format fmt p; MarkedString.format fmt i + let equal (p1, i1) (p2, i2) = + Path.equal p1 p2 && MarkedString.equal i1 i2 + let compare (p1, i1) (p2, i2) = + match Path.compare p1 p2 with + | 0 -> MarkedString.compare i1 i2 + | n -> n +end + +module Gen_qualified () = struct + include Make (QualifiedMarkedString) () + + let fresh path t = fresh (path, t) + let path t = fst (get_info t) + let get_info t = snd (get_info t) +end diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index aa479267..196bb96b 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -60,3 +60,35 @@ module Make (X : Info) () : Id with type info = X.info module Gen () : Id with type info = MarkedString.info (** Shortcut for creating a kind of uids over marked strings *) + +(** {2 Handling of Uids with additional path information} *) + +module Module: sig + type t = private string (* TODO: this will become an uid at some point *) + val to_string: t -> string + val format: Format.formatter -> t -> unit + val equal : t -> t-> bool + val compare : t -> t -> int + val of_string: string -> t + + module Set : Set.S with type elt = t + module Map : Map.S with type key = t +end + +module Path: sig + type t = Module.t list + + val to_string: t -> string + val format: Format.formatter -> t -> unit + val equal : t -> t-> bool + val compare : t -> t -> int +end + +(** Same as [Gen] but also registers path information *) +module Gen_qualified () : sig + include Id with type info = Path.t * MarkedString.info + val fresh : Path.t -> MarkedString.info -> t + + val path : t -> Path.t + val get_info : t -> MarkedString.info +end diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 15b0ae86..b4ed37ff 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -31,7 +31,7 @@ type scope_input_var_ctx = { type 'm scope_ref = | Local_scope_ref of 'm Ast.expr Var.t - | External_scope_ref of path * ScopeName.t Mark.pos + | External_scope_ref of ScopeName.t Mark.pos type 'm scope_sig_ctx = { scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *) @@ -73,15 +73,12 @@ let pos_mark_mk (type a m) (e : (a, m) gexpr) : let pos_mark_as e = pos_mark (Mark.get e) in pos_mark, pos_mark_as -let rec module_scope_sig scope_sig_ctx path scope = - match path with - | [] -> ScopeName.Map.find scope scope_sig_ctx.scope_sigs - | (modname, mpos) :: path -> ( - match ModuleName.Map.find_opt modname scope_sig_ctx.scope_sigs_modules with - | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format - modname - | Some sig_ctx -> module_scope_sig sig_ctx path scope) +let module_scope_sig scope_sig_ctx scope = + let ssctx = + List.fold_left (fun ssctx m -> ModuleName.Map.find m ssctx.scope_sigs_modules) + scope_sig_ctx (ScopeName.path scope) + in + ScopeName.Map.find scope ssctx.scope_sigs let merge_defaults ~(is_func : bool) @@ -214,7 +211,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let m = Mark.get e in match Mark.remove e with | EMatch { e = e1; name; cases = e_cases } -> - let path, enum_sig = EnumName.Map.find name ctx.decl_ctx.ctx_enums in + let enum_sig = EnumName.Map.find name ctx.decl_ctx.ctx_enums in let d_cases, remaining_e_cases = (* FIXME: these checks should probably be moved to a better place *) EnumConstructor.Map.fold @@ -223,9 +220,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : try EnumConstructor.Map.find constructor e_cases with EnumConstructor.Map.Not_found _ -> Message.raise_spanned_error (Expr.pos e) - "The constructor %a of enum %a%a is missing from this pattern \ + "The constructor %a of enum %a is missing from this pattern \ matching" - EnumConstructor.format constructor Print.path path + EnumConstructor.format constructor EnumName.format name in let case_d = translate_expr ctx case_e in @@ -236,16 +233,16 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : in if not (EnumConstructor.Map.is_empty remaining_e_cases) then Message.raise_spanned_error (Expr.pos e) - "Pattern matching is incomplete for enum %a%a: missing cases %a" - Print.path path EnumName.format name + "Pattern matching is incomplete for enum %a: missing cases %a" + EnumName.format name (EnumConstructor.Map.format_keys ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")) remaining_e_cases; let e1 = translate_expr ctx e1 in Expr.ematch ~e:e1 ~name ~cases:d_cases m - | EScopeCall { path; scope; args } -> + | EScopeCall { scope; args } -> let pos = Expr.mark_pos m in - let sc_sig = module_scope_sig ctx.scopes_parameters path scope in + let sc_sig = module_scope_sig ctx.scopes_parameters scope in let in_var_map = ScopeVar.Map.merge (fun var_name (str_field : scope_input_var_ctx option) expr -> @@ -300,8 +297,8 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let e = match sc_sig.scope_sig_scope_ref with | Local_scope_ref v -> Expr.evar v m - | External_scope_ref (path, name) -> - Expr.eexternal ~path + | External_scope_ref name -> + Expr.eexternal ~name:(Mark.map (fun s -> External_scope s) name) m in @@ -411,9 +408,8 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : EndCall f_markings) ts_in (Expr.pos e) | _ -> original_field_expr) - (snd - (StructName.Map.find sc_sig.scope_sig_output_struct - ctx.decl_ctx.ctx_structs))) + (StructName.Map.find sc_sig.scope_sig_output_struct + ctx.decl_ctx.ctx_structs)) (Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e)) in (* Here we have to go through an if statement that records a decision being @@ -497,8 +493,8 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : ctx.subscope_vars |> SubScopeName.Map.find (Mark.remove alias) |> retrieve_in_and_out_typ_or_any var - | ELocation (ToplevelVar { path; name }) -> ( - let decl_ctx = Program.module_ctx ctx.decl_ctx path in + | ELocation (ToplevelVar { name }) -> ( + let decl_ctx = Program.module_ctx ctx.decl_ctx (TopdefName.path (Mark.remove name)) in let typ = TopdefName.Map.find (Mark.remove name) decl_ctx.ctx_topdefs in match Mark.remove typ with | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout @@ -572,11 +568,13 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : %a's results. Maybe you forgot to qualify it as an output?" SubScopeName.format (Mark.remove s) ScopeVar.format (Mark.remove a) SubScopeName.format (Mark.remove s)) - | ELocation (ToplevelVar { path = []; name }) -> - let v, _ = TopdefName.Map.find (Mark.remove name) ctx.toplevel_vars in - Expr.evar v m - | ELocation (ToplevelVar { path = _ :: _ as path; name }) -> - Expr.eexternal ~path ~name:(Mark.map (fun n -> External_value n) name) m + | ELocation (ToplevelVar { name }) -> + let path = TopdefName.path (Mark.remove name) in + if path = [] then + let v, _ = TopdefName.Map.find (Mark.remove name) ctx.toplevel_vars in + Expr.evar v m + else + Expr.eexternal ~name:(Mark.map (fun n -> External_value n) name) m | EOp { op = Add_dat_dur _; tys } -> Expr.eop (Add_dat_dur ctx.date_rounding) tys m | EOp { op; tys } -> Expr.eop (Operator.translate op) tys m @@ -710,11 +708,11 @@ let translate_rule (* A global variable can't be defined locally. The [Definition] constructor could be made more specific to avoid this case, but the added complexity didn't seem worth it *) - | Call ((path, subname), subindex, m) -> - let subscope_sig = module_scope_sig ctx.scopes_parameters path subname in + | Call (subname, subindex, m) -> + let subscope_sig = module_scope_sig ctx.scopes_parameters subname in let scope_sig_decl = ScopeName.Map.find subname - (Program.module_ctx ctx.decl_ctx path).ctx_scopes + (Program.module_ctx ctx.decl_ctx (ScopeName.path subname)).ctx_scopes in let all_subscope_vars = subscope_sig.scope_sig_local_vars in let all_subscope_input_vars = @@ -736,8 +734,8 @@ let translate_rule let m = mark_tany m pos_call in match subscope_sig.scope_sig_scope_ref with | Local_scope_ref var -> Expr.make_var var m - | External_scope_ref (path, name) -> - Expr.eexternal ~path ~name:(Mark.map (fun n -> External_scope n) name) m + | External_scope_ref name -> + Expr.eexternal ~name:(Mark.map (fun n -> External_scope n) name) m in let called_scope_input_struct = subscope_sig.scope_sig_input_struct in let called_scope_return_struct = subscope_sig.scope_sig_output_struct in @@ -1069,7 +1067,7 @@ let translate_scope_decl StructField.Map.empty scope_input_variables in let new_struct_ctx = - StructName.Map.singleton scope_input_struct_name ([], field_map) + StructName.Map.singleton scope_input_struct_name field_map in scope_body, new_struct_ctx @@ -1088,18 +1086,18 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = prg.Scopelang.Ast.program_scopes)) prgm.program_modules; let sctx : 'm scope_sigs_ctx = - let process_scope_sig (scope_path, scope_name) scope = - Message.emit_debug "process_scope_sig %a%a (%a)" Print.path scope_path + let process_scope_sig scope_name scope = + Message.emit_debug "process_scope_sig %a (%a)" ScopeName.format scope_name ScopeName.format scope.Scopelang.Ast.scope_decl_name; + let scope_path = ScopeName.path scope_name in let scope_ref = - match scope_path with - | [] -> + if scope_path = [] then let v = Var.make (Mark.remove (ScopeName.get_info scope_name)) in Local_scope_ref v - | path -> + else External_scope_ref - (path, Mark.copy (ScopeName.get_info scope_name) scope_name) + (Mark.copy (ScopeName.get_info scope_name) scope_name) in let scope_info = try @@ -1108,7 +1106,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = with ScopeName.Map.Not_found _ -> Message.raise_spanned_error (Mark.get (ScopeName.get_info scope_name)) - "Could not find scope %a%a" Print.path scope_path ScopeName.format + "Could not find scope %a" ScopeName.format scope_name in let scope_sig_in_fields = @@ -1148,17 +1146,15 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = scope_sig_in_fields; } in - let rec process_modules path prg = + let rec process_modules prg = { scope_sigs = ScopeName.Map.mapi (fun scope_name (scope_decl, _) -> - process_scope_sig (path, scope_name) scope_decl) + process_scope_sig scope_name scope_decl) prg.Scopelang.Ast.program_scopes; scope_sigs_modules = - ModuleName.Map.mapi - (fun modname prg -> - process_modules (path @ [modname, Pos.no_pos]) prg) + ModuleName.Map.map process_modules prg.Scopelang.Ast.program_modules; } in @@ -1166,21 +1162,20 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = scope_sigs = ScopeName.Map.mapi (fun scope_name (scope_decl, _) -> - process_scope_sig ([], scope_name) scope_decl) + process_scope_sig scope_name scope_decl) prgm.Scopelang.Ast.program_scopes; scope_sigs_modules = - ModuleName.Map.mapi - (fun modname prg -> process_modules [modname, Pos.no_pos] prg) + ModuleName.Map.map + process_modules prgm.Scopelang.Ast.program_modules; } in - let rec gather_module_in_structs acc path sctx = + let rec gather_module_in_structs acc sctx = (* Expose all added in_structs from submodules at toplevel *) ModuleName.Map.fold - (fun modname scope_sigs acc -> - let path = path @ [modname, Pos.no_pos] in + (fun _ scope_sigs acc -> let acc = - gather_module_in_structs acc path scope_sigs.scope_sigs_modules + gather_module_in_structs acc scope_sigs.scope_sigs_modules in ScopeName.Map.fold (fun _ scope_sig_ctx acc -> @@ -1196,7 +1191,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = scope_sig_ctx.scope_sig_in_fields StructField.Map.empty in StructName.Map.add scope_sig_ctx.scope_sig_input_struct - (path, fields) acc) + fields acc) scope_sigs.scope_sigs acc) sctx acc in @@ -1204,7 +1199,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = { decl_ctx with ctx_structs = - gather_module_in_structs decl_ctx.ctx_structs [] sctx.scope_sigs_modules; + gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules; } in let top_ctx = diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 6569579f..1ef669eb 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -215,7 +215,7 @@ type var_or_states = WholeVar | States of StateName.t list type scope = { scope_vars : var_or_states ScopeVar.Map.t; - scope_sub_scopes : (path * ScopeName.t) SubScopeName.Map.t; + scope_sub_scopes : ScopeName.t SubScopeName.Map.t; scope_uid : ScopeName.t; scope_defs : scope_def ScopeDef.Map.t; scope_assertions : assertion AssertionName.Map.t; diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index 77dd3054..bde89f85 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -104,7 +104,7 @@ type var_or_states = WholeVar | States of StateName.t list type scope = { scope_vars : var_or_states ScopeVar.Map.t; - scope_sub_scopes : (path * ScopeName.t) SubScopeName.Map.t; + scope_sub_scopes : ScopeName.t SubScopeName.Map.t; scope_uid : ScopeName.t; scope_defs : scope_def ScopeDef.Map.t; scope_assertions : assertion AssertionName.Map.t; diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 05b4761f..22f8d558 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -169,6 +169,7 @@ let rec disambiguate_constructor Message.raise_spanned_error pos "Enum %s does not contain case %s" (Mark.remove enum) (Mark.remove constructor)) | (modname, mpos) :: path -> ( + let modname = ModuleName.of_string modname in match ModuleName.Map.find_opt modname ctxt.modules with | None -> Message.raise_spanned_error mpos "Module %a not found" ModuleName.format @@ -373,7 +374,7 @@ let rec translate_expr | Some v -> Expr.elocation (ToplevelVar - { path = []; name = v, Mark.get (TopdefName.get_info v) }) + { name = v, Mark.get (TopdefName.get_info v) }) emark | None -> Name_resolution.raise_unknown_identifier @@ -383,7 +384,7 @@ let rec translate_expr match Ident.Map.find_opt (Mark.remove name) ctxt.topdefs with | Some v -> Expr.elocation - (ToplevelVar { path; name = v, Mark.get (TopdefName.get_info v) }) + (ToplevelVar { name = v, Mark.get (TopdefName.get_info v) }) emark | None -> Name_resolution.raise_unknown_identifier "for an external variable" name) @@ -393,19 +394,17 @@ let rec translate_expr when Option.fold scope ~none:false ~some:(fun s -> Name_resolution.is_subscope_uid s ctxt y) -> (* In this case, y.x is a subscope variable *) - let subscope_uid, (subscope_path, subscope_real_uid) = + let subscope_uid, subscope_real_uid = match Ident.Map.find y scope_vars with | SubScope (sub, sc) -> sub, sc | ScopeVar _ -> assert false in let subscope_var_uid = - let ctxt = Name_resolution.module_ctx ctxt subscope_path in Name_resolution.get_var_uid subscope_real_uid ctxt x in Expr.elocation (SubScopeVar { - path = subscope_path; scope = subscope_real_uid; alias = subscope_uid, pos; var = subscope_var_uid, pos; @@ -418,6 +417,7 @@ let rec translate_expr | [] -> None | [c] -> Some (Name_resolution.get_struct ctxt c) | (modname, mpos) :: path -> ( + let modname = ModuleName.of_string modname in match ModuleName.Map.find_opt modname ctxt.modules with | None -> Message.raise_spanned_error mpos "Module %a not found" @@ -425,7 +425,7 @@ let rec translate_expr | Some ctxt -> get_str ctxt path) in Expr.edstructaccess ~e ~field:(Mark.remove x) - ~name_opt:(get_str ctxt path) ~path emark) + ~name_opt:(get_str ctxt path) emark) | FunCall (f, args) -> Expr.eapp (rec_helper f) (List.map rec_helper args) emark | ScopeCall (((path, id), _), fields) -> @@ -467,7 +467,7 @@ let rec translate_expr acc) ScopeVar.Map.empty fields in - Expr.escopecall ~path ~scope:called_scope ~args:in_struct emark + Expr.escopecall ~scope:called_scope ~args:in_struct emark | LetIn (x, e1, e2) -> let v = Var.make (Mark.remove x) in let local_vars = Ident.Map.add (Mark.remove x) v local_vars in @@ -1391,11 +1391,11 @@ let init_scope_defs (scope_def_map, 0) states in scope_def) - | Name_resolution.SubScope (v0, (path, subscope_uid)) -> - let ctxt = Name_resolution.module_ctx ctxt path in + | Name_resolution.SubScope (v0, subscope_uid) -> let sub_scope_def = - ScopeName.Map.find subscope_uid ctxt.Name_resolution.scopes + Name_resolution.get_scope_context ctxt subscope_uid in + let ctxt = List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.Name_resolution.modules) ctxt (ScopeName.path subscope_uid) in Ident.Map.fold (fun _ v scope_def_map -> match v with @@ -1469,34 +1469,25 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : Ast.program_ctx = { (* After name resolution, type definitions (structs and enums) are - exposed at toplevel for easier lookup, but their paths need to - remain available for printing and later passes *) + exposed at toplevel for easier lookup *) ctx_structs = ModuleName.Map.fold - (fun modname prg acc -> + (fun _ prg acc -> StructName.Map.union (fun _ _ _ -> assert false) acc - (StructName.Map.map - (fun (path, def) -> (modname, Pos.no_pos) :: path, def) - prg.Ast.program_ctx.ctx_structs)) + prg.Ast.program_ctx.ctx_structs) submodules - (StructName.Map.map - (fun def -> [], def) - ctxt.Name_resolution.structs); + ctxt.Name_resolution.structs; ctx_enums = ModuleName.Map.fold - (fun modname prg acc -> + (fun _ prg acc -> EnumName.Map.union (fun _ _ _ -> assert false) acc - (EnumName.Map.map - (fun (path, def) -> (modname, Pos.no_pos) :: path, def) - prg.Ast.program_ctx.ctx_enums)) + prg.Ast.program_ctx.ctx_enums) submodules - (EnumName.Map.map - (fun def -> [], def) - ctxt.Name_resolution.enums); + ctxt.Name_resolution.enums; ctx_scopes = Ident.Map.fold (fun _ def acc -> @@ -1546,10 +1537,11 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : let desugared = List.fold_left (fun acc (id, intf) -> + let id = ModuleName.of_string id in let modul = ModuleName.Map.find id acc.Ast.program_modules in let modul = process_code_block - (Name_resolution.module_ctx ctxt [id, Pos.no_pos]) + (ModuleName.Map.find id ctxt.modules) modul intf in { diff --git a/compiler/desugared/linting.ml b/compiler/desugared/linting.ml index a4d9b2af..e45b717a 100644 --- a/compiler/desugared/linting.ml +++ b/compiler/desugared/linting.ml @@ -109,7 +109,7 @@ let detect_unused_struct_fields (p : program) : unit = let rec structs_fields_used_expr e struct_fields_used = match Mark.remove e with | EDStructAccess - { name_opt = Some name; e = e_struct; field; path = _ } -> + { name_opt = Some name; e = e_struct; field } -> let field = StructName.Map.find name (Ident.Map.find field p.program_ctx.ctx_struct_fields) @@ -136,8 +136,8 @@ let detect_unused_struct_fields (p : program) : unit = p.program_ctx.ctx_scopes StructField.Set.empty in StructName.Map.iter - (fun s_name (path, fields) -> - if path <> [] then () + (fun s_name fields -> + if StructName.path s_name <> [] then () else if (not (StructField.Map.is_empty fields)) && StructField.Map.for_all @@ -192,8 +192,8 @@ let detect_unused_enum_constructors (p : program) : unit = ~init:EnumConstructor.Set.empty p in EnumName.Map.iter - (fun e_name (path, constructors) -> - if path <> [] then () + (fun e_name constructors -> + if EnumName.path e_name <> [] then () else if EnumConstructor.Map.for_all (fun cons _ -> diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index bea4c5d4..43c882e0 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -32,7 +32,7 @@ type scope_def_context = { type scope_var_or_subscope = | ScopeVar of ScopeVar.t - | SubScope of SubScopeName.t * (path * ScopeName.t) + | SubScope of SubScopeName.t * ScopeName.t type scope_context = { var_idmap : scope_var_or_subscope Ident.Map.t; @@ -68,6 +68,7 @@ type typedef = | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) type context = { + path : Uid.Path.t; typedefs : typedef Ident.Map.t; (** Gathers the names of the scopes, structs and enums *) field_idmap : StructField.t StructName.Map.t Ident.Map.t; @@ -112,12 +113,23 @@ let get_var_io (ctxt : context) (uid : ScopeVar.t) : Surface.Ast.scope_decl_context_io = (ScopeVar.Map.find uid ctxt.var_typs).var_sig_io +let get_scope_context (ctxt: context) (scope: ScopeName.t) : scope_context = + let rec remove_common_prefix curpath scpath = match curpath, scpath with + | m1 :: cp, m2 :: sp when ModuleName.equal m1 m2 -> remove_common_prefix cp sp + | _ -> scpath + in + let path = remove_common_prefix ctxt.path (ScopeName.path scope) in + let ctxt = + List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.modules) ctxt path + in + ScopeName.Map.find scope ctxt.scopes + (** Get the variable uid inside the scope given in argument *) let get_var_uid (scope_uid : ScopeName.t) (ctxt : context) ((x, pos) : Ident.t Mark.pos) : ScopeVar.t = - let scope = ScopeName.Map.find scope_uid ctxt.scopes in + let scope = get_scope_context ctxt scope_uid in match Ident.Map.find_opt x scope.var_idmap with | Some (ScopeVar uid) -> uid | _ -> @@ -130,7 +142,7 @@ let get_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) ((y, pos) : Ident.t Mark.pos) : SubScopeName.t = - let scope = ScopeName.Map.find scope_uid ctxt.scopes in + let scope = get_scope_context ctxt scope_uid in match Ident.Map.find_opt y scope.var_idmap with | Some (SubScope (sub_uid, _sub_id)) -> sub_uid | _ -> raise_unknown_identifier "for a subscope of this scope" (y, pos) @@ -139,7 +151,7 @@ let get_subscope_uid subscopes of [scope_uid]. *) let is_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) (y : Ident.t) : bool = - let scope = ScopeName.Map.find scope_uid ctxt.scopes in + let scope = get_scope_context ctxt scope_uid in match Ident.Map.find_opt y scope.var_idmap with | Some (SubScope _) -> true | _ -> false @@ -147,7 +159,7 @@ let is_subscope_uid (scope_uid : ScopeName.t) (ctxt : context) (y : Ident.t) : (** Checks if the var_uid belongs to the scope scope_uid *) let belongs_to (ctxt : context) (uid : ScopeVar.t) (scope_uid : ScopeName.t) : bool = - let scope = ScopeName.Map.find scope_uid ctxt.scopes in + let scope = get_scope_context ctxt scope_uid in Ident.Map.exists (fun _ -> function | ScopeVar var_uid -> ScopeVar.equal uid var_uid @@ -241,6 +253,7 @@ let rec module_ctx ctxt path = match path with | [] -> ctxt | (modname, mpos) :: path -> ( + let modname = ModuleName.of_string modname in match ModuleName.Map.find_opt modname ctxt.modules with | None -> Message.raise_spanned_error mpos "Module %a not found" ModuleName.format @@ -256,7 +269,7 @@ let process_subscope_decl (decl : Surface.Ast.scope_decl_context_scope) : context = let name, name_pos = decl.scope_decl_context_scope_name in let (path, subscope), s_pos = decl.scope_decl_context_scope_sub_scope in - let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in + let scope_ctxt = get_scope_context ctxt scope in match Ident.Map.find_opt (Mark.remove subscope) scope_ctxt.var_idmap with | Some use -> let info = @@ -278,7 +291,7 @@ let process_subscope_decl scope_ctxt with var_idmap = Ident.Map.add name - (SubScope (sub_scope_uid, (path, original_subscope_uid))) + (SubScope (sub_scope_uid, original_subscope_uid)) scope_ctxt.var_idmap; sub_scopes = ScopeName.Set.add original_subscope_uid scope_ctxt.sub_scopes; @@ -324,6 +337,7 @@ let rec process_base_typ declared" ident) | Surface.Ast.Named ((modul, mpos) :: path, id) -> ( + let modul = ModuleName.of_string modul in match ModuleName.Map.find_opt modul ctxt.modules with | None -> Message.raise_spanned_error mpos @@ -351,7 +365,7 @@ let process_data_decl let data_typ = process_type ctxt decl.scope_decl_context_item_typ in let is_cond = is_type_cond decl.scope_decl_context_item_typ in let name, pos = decl.scope_decl_context_item_name in - let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in + let scope_ctxt = get_scope_context ctxt scope in match Ident.Map.find_opt name scope_ctxt.var_idmap with | Some use -> let info = @@ -568,7 +582,7 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : } in let out_struct_fields = - let sco = ScopeName.Map.find scope_uid ctxt.scopes in + let sco = get_scope_context ctxt scope_uid in let str = get_struct ctxt decl.scope_decl_name in Ident.Map.fold (fun id var svmap -> @@ -621,9 +635,9 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (fun use -> raise_already_defined_error (typedef_info use) name pos "scope") (Ident.Map.find_opt name ctxt.typedefs); - let scope_uid = ScopeName.fresh (name, pos) in - let in_struct_name = StructName.fresh (name ^ "_in", pos) in - let out_struct_name = StructName.fresh (name, pos) in + let scope_uid = ScopeName.fresh ctxt.path (name, pos) in + let in_struct_name = StructName.fresh ctxt.path (name ^ "_in", pos) in + let out_struct_name = StructName.fresh ctxt.path (name, pos) in { ctxt with typedefs = @@ -651,7 +665,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (fun use -> raise_already_defined_error (typedef_info use) name pos "struct") (Ident.Map.find_opt name ctxt.typedefs); - let s_uid = StructName.fresh sdecl.struct_decl_name in + let s_uid = StructName.fresh ctxt.path sdecl.struct_decl_name in { ctxt with typedefs = @@ -665,7 +679,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (fun use -> raise_already_defined_error (typedef_info use) name pos "enum") (Ident.Map.find_opt name ctxt.typedefs); - let e_uid = EnumName.fresh edecl.enum_decl_name in + let e_uid = EnumName.fresh ctxt.path edecl.enum_decl_name in { ctxt with typedefs = @@ -681,7 +695,7 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : raise_already_defined_error (TopdefName.get_info use) name pos "toplevel definition") (Ident.Map.find_opt name ctxt.topdefs); - let uid = TopdefName.fresh def.topdef_name in + let uid = TopdefName.fresh ctxt.path def.topdef_name in { ctxt with topdefs = Ident.Map.add name uid ctxt.topdefs; @@ -762,8 +776,8 @@ let get_def_key ScopeVar.format x_uid else None ) | [y; x] -> - let (subscope_uid, (path, subscope_real_uid)) - : SubScopeName.t * (path * ScopeName.t) = + let (subscope_uid, subscope_real_uid) + : SubScopeName.t * ScopeName.t = match Ident.Map.find_opt (Mark.remove y) scope_ctxt.var_idmap with | Some (SubScope (v, u)) -> v, u | Some _ -> @@ -775,7 +789,6 @@ let get_def_key Print.lit_style (Mark.remove y) in let x_uid = - let ctxt = module_ctx ctxt path in get_var_uid subscope_real_uid ctxt x in Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid, pos) @@ -924,6 +937,7 @@ let process_use_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : let empty_ctxt = { + path = []; typedefs = Ident.Map.empty; scopes = ScopeName.Map.empty; topdefs = Ident.Map.empty; @@ -937,13 +951,14 @@ let empty_ctxt = } let import_module modules (name, intf) = - let ctxt = { empty_ctxt with modules } in + let mname = ModuleName.of_string name in + let ctxt = { empty_ctxt with modules; path = [mname] } in let ctxt = List.fold_left process_name_item ctxt intf in let ctxt = List.fold_left process_decl_item ctxt intf in let ctxt = { ctxt with modules = empty_ctxt.modules } in (* No submodules at the moment, a module may use the ones loaded before it, but doesn't reexport them *) - ModuleName.Map.add name ctxt modules + ModuleName.Map.add mname ctxt modules (** Derive the context from metadata, in one pass over the declarations *) let form_context (prgm : Surface.Ast.program) : context = diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index 1267e27d..fa2c0206 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -32,7 +32,7 @@ type scope_def_context = { type scope_var_or_subscope = | ScopeVar of ScopeVar.t - | SubScope of SubScopeName.t * (path * ScopeName.t) + | SubScope of SubScopeName.t * ScopeName.t type scope_context = { var_idmap : scope_var_or_subscope Ident.Map.t; @@ -68,6 +68,8 @@ type typedef = | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) type context = { + path : ModuleName.t list; + (** The current path being processed. Used for generating the Uids. *) typedefs : typedef Ident.Map.t; (** Gathers the names of the scopes, structs and enums *) field_idmap : StructField.t StructName.Map.t Ident.Map.t; @@ -105,6 +107,9 @@ val get_var_typ : context -> ScopeVar.t -> typ val is_var_cond : context -> ScopeVar.t -> bool val get_var_io : context -> ScopeVar.t -> Surface.Ast.scope_decl_context_io +val get_scope_context : context -> ScopeName.t -> scope_context +(** Get the corresponding scope context from the context, looking up into nested submodules as necessary, following the path information in the scope name *) + val get_var_uid : ScopeName.t -> context -> Ident.t Mark.pos -> ScopeVar.t (** Get the variable uid inside the scope given in argument *) @@ -151,7 +156,7 @@ val get_scope : context -> Ident.t Mark.pos -> ScopeName.t (** Find a scope definition from the typedefs, failing if there is none or it has a different kind *) -val module_ctx : context -> path -> context +val module_ctx : context -> Surface.Ast.path -> context (** Returns the context corresponding to the given module path; raises a user error if the module is not found *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 7258a5cd..15ac2186 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -255,7 +255,7 @@ module Commands = struct variable ScopeName.format scope_uid | Some (Desugared.Name_resolution.SubScope - (subscope_var_name, (subscope_path, subscope_name))) -> ( + (subscope_var_name, subscope_name)) -> ( match second_part with | None -> Message.raise_error @@ -265,7 +265,10 @@ module Commands = struct SubScopeName.format subscope_var_name ScopeName.format scope_uid | Some second_part -> ( match - let ctxt = Desugared.Name_resolution.module_ctx ctxt subscope_path in + let ctxt = Desugared.Name_resolution.module_ctx ctxt + (List.map (fun m -> ModuleName.to_string m, Pos.no_pos) + (ScopeName.path subscope_name)) + in Ident.Map.find_opt second_part (ScopeName.Map.find subscope_name ctxt.scopes).var_idmap with diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 510ff89d..e12cfcd3 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -333,11 +333,11 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = | TEnum e -> EnumConstructor.Map.exists (fun _ t' -> type_contains_arrow t') - (snd (EnumName.Map.find e p.decl_ctx.ctx_enums)) + (EnumName.Map.find e p.decl_ctx.ctx_enums) | TStruct s -> StructField.Map.exists (fun _ t' -> type_contains_arrow t') - (snd (StructName.Map.find s p.decl_ctx.ctx_structs)) + (StructName.Map.find s p.decl_ctx.ctx_structs) in let replace_fun_typs t = if type_contains_arrow t then Mark.copy t TAny else t @@ -346,11 +346,11 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = p.decl_ctx with ctx_structs = StructName.Map.map - (fun (p, def) -> p, StructField.Map.map replace_fun_typs def) + (StructField.Map.map replace_fun_typs) p.decl_ctx.ctx_structs; ctx_enums = EnumName.Map.map - (fun (p, def) -> p, EnumConstructor.Map.map replace_fun_typs def) + (EnumConstructor.Map.map replace_fun_typs) p.decl_ctx.ctx_enums; } in @@ -552,7 +552,7 @@ let rec hoist_closures_code_item_list (fun next_code_items closure -> Cons ( Topdef - ( TopdefName.fresh + ( TopdefName.fresh [] ( Bindlib.name_of hoisted_closure.name, Expr.mark_pos closure_mark ), hoisted_closure.ty, diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 3d37ef29..694025c3 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -745,7 +745,7 @@ let translate_program (prgm : typed D.program) : untyped A.program = prgm.decl_ctx with ctx_enums = prgm.decl_ctx.ctx_enums - |> EnumName.Map.add Expr.option_enum ([], Expr.option_enum_config); + |> EnumName.Map.add Expr.option_enum Expr.option_enum_config; } in let decl_ctx = @@ -753,8 +753,8 @@ let translate_program (prgm : typed D.program) : untyped A.program = decl_ctx with ctx_structs = prgm.decl_ctx.ctx_structs - |> StructName.Map.mapi (fun _n (path, str) -> - path, StructField.Map.map trans_typ_keep str); + |> StructName.Map.mapi (fun _n str -> + StructField.Map.map trans_typ_keep str); } in diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 19ead856..8641ea2a 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -274,8 +274,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : in match Mark.remove e with | EVar v -> Format.fprintf fmt "%a" format_var v - | EExternal { path; name } -> ( - Print.path fmt path; + | EExternal { name } -> ( (* FIXME: this is wrong in general !! We assume the idents exposed by the module depend only on the original name, while they actually get through Bindlib and may have been renamed. A correct implem could use the runtime @@ -555,11 +554,13 @@ let format_ctx (fun struct_or_enum -> match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> - let path, def = StructName.Map.find s ctx.ctx_structs in - if path = [] then Format.fprintf fmt "%a@\n" format_struct_decl (s, def) + let def = StructName.Map.find s ctx.ctx_structs in + if StructName.path s = [] then + Format.fprintf fmt "%a@\n" format_struct_decl (s, def) | Scopelang.Dependency.TVertex.Enum e -> - let path, def = EnumName.Map.find e ctx.ctx_enums in - if path = [] then Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) + let def = EnumName.Map.find e ctx.ctx_enums in + if EnumName.path e = [] then + Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) (type_ordering @ scope_structs) let rename_vars e = @@ -618,7 +619,7 @@ let format_scope_exec scope_body = let scope_name_str = Mark.remove (ScopeName.get_info scope_name) in let scope_var = String.Map.find scope_name_str bnd in - let _, scope_input = + let scope_input = StructName.Map.find scope_body.scope_body_input_struct ctx.ctx_structs in if not (StructField.Map.is_empty scope_input) then diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index 5fc2589a..62c42fbe 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -139,10 +139,9 @@ module To_jsoo = struct | TArrow _ -> Format.fprintf fmt "Js.meth" | _ -> Format.fprintf fmt "Js.readonly_prop" in - let format_struct_decl fmt (struct_name, (path, struct_fields)) = + let format_struct_decl fmt (struct_name, struct_fields) = let fmt_struct_name fmt _ = format_struct_name fmt struct_name in let fmt_module_struct_name fmt _ = - Print.path fmt path; To_ocaml.format_to_module_name fmt (`Sname struct_name) in let fmt_to_jsoo fmt _ = @@ -233,10 +232,9 @@ module To_jsoo = struct in let format_enum_decl fmt - (enum_name, (path, (enum_cons : typ EnumConstructor.Map.t))) = + (enum_name, (enum_cons : typ EnumConstructor.Map.t)) = let fmt_enum_name fmt _ = format_enum_name fmt enum_name in let fmt_module_enum_name fmt () = - Print.path fmt path; To_ocaml.format_to_module_name fmt (`Ename enum_name) in let fmt_to_jsoo fmt _ = diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index f853881a..ace37150 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -76,11 +76,11 @@ module To_json = struct (ctx : decl_ctx) (fmt : Format.formatter) (sname : StructName.t) = - let path, fields = StructName.Map.find sname ctx.ctx_structs in + let fields = StructName.Map.find sname ctx.ctx_structs in Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") (fun fmt (field_name, field_type) -> - Format.fprintf fmt "@[\"%a%a\": {@\n%a@]@\n}" Print.path path + Format.fprintf fmt "@[\"%a\": {@\n%a@]@\n}" format_struct_field_name_camel_case field_name fmt_type field_type) fmt (StructField.Map.bindings fields) @@ -105,18 +105,17 @@ module To_json = struct | TEnum e -> List.fold_left collect (t :: acc) (EnumConstructor.Map.values - (snd (EnumName.Map.find e ctx.ctx_enums))) + (EnumName.Map.find e ctx.ctx_enums)) | TArray t -> collect acc t | _ -> acc in StructName.Map.find input_struct ctx.ctx_structs - |> snd |> StructField.Map.values |> List.fold_left (fun acc field_typ -> collect acc field_typ) [] |> List.sort_uniq (fun t t' -> String.compare (get_name t) (get_name t')) in let fmt_enum_properties fmt ename = - let _path, enum_def = EnumName.Map.find ename ctx.ctx_enums in + let enum_def = EnumName.Map.find ename ctx.ctx_enums in Format.fprintf fmt "@[\"kind\": {@\n\ \"type\": \"string\",@\n\ diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index b5a0d3bb..68d6866a 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -243,7 +243,7 @@ let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : (Bindlib.box EEmptyError, Expr.with_ty m ty_out) ty_in (Expr.mark_pos m) | ty -> Expr.evar (Var.make "undefined_input") (Expr.with_ty m ty)) - (snd (StructName.Map.find scope_arg_struct ctx.ctx_structs))) + (StructName.Map.find scope_arg_struct ctx.ctx_structs)) m in let e_app = Expr.eapp (Expr.box e) [application_arg] m in diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index e8af9a36..90fba913 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -42,8 +42,8 @@ let rec format_expr | EVar v -> Format.fprintf fmt "%a" format_var_name v | EFunc v -> Format.fprintf fmt "%a" format_func_name v | EStruct (es, s) -> - let path, fields = StructName.Map.find s decl_ctx.ctx_structs in - Format.fprintf fmt "@[%a%a@ %a%a%a@]" Print.path path + let fields = StructName.Map.find s decl_ctx.ctx_structs in + Format.fprintf fmt "@[%a@ %a%a%a@]" StructName.format s Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") @@ -142,15 +142,15 @@ let rec format_statement (format_expr decl_ctx ~debug) (naked_expr, Mark.get stmt) | SSwitch (e_switch, enum, arms) -> - let path, cons = EnumName.Map.find enum decl_ctx.ctx_enums in + let cons = EnumName.Map.find enum decl_ctx.ctx_enums in Format.fprintf fmt "@[%a @[%a@]%a@]%a" Print.keyword "switch" (format_expr decl_ctx ~debug) e_switch Print.punctuation ":" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt ((case, _), (arm_block, payload_name)) -> - Format.fprintf fmt "%a %a%a%a@ %a @[%a@ %a@]" Print.punctuation - "|" Print.path path Print.enum_constructor case Print.punctuation + Format.fprintf fmt "%a %a%a@ %a @[%a@ %a@]" Print.punctuation + "|" Print.enum_constructor case Print.punctuation ":" format_var_name payload_name Print.punctuation "→" (format_block decl_ctx ~debug) arm_block)) diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index ebc91349..ae8cff2f 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -274,8 +274,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : | EVar v -> format_var fmt v | EFunc f -> format_func_name fmt f | EStruct (es, s) -> - let path, fields = StructName.Map.find s ctx.ctx_structs in - Format.fprintf fmt "%a%a(%a)" Print.path path format_struct_name s + let fields = StructName.Map.find s ctx.ctx_structs in + Format.fprintf fmt "%a(%a)" format_struct_name s (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, (struct_field, _)) -> @@ -426,7 +426,7 @@ let rec format_statement (format_block ctx) case_none format_var case_some_var format_var tmp_var (format_block ctx) case_some | SSwitch (e1, e_name, cases) -> - let path, cons_map = EnumName.Map.find e_name ctx.ctx_enums in + let cons_map = EnumName.Map.find e_name ctx.ctx_enums in let cases = List.map2 (fun (x, y) (cons, _) -> x, y, cons) @@ -439,8 +439,8 @@ let rec format_statement (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt (case_block, payload_var, cons_name) -> - Format.fprintf fmt "%a.code == %a%a_Code.%a:@\n%a = %a.value@\n%a" - format_var tmp_var Print.path path format_enum_name e_name + Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" + format_var tmp_var format_enum_name e_name format_enum_cons_name cons_name format_var payload_var format_var tmp_var (format_block ctx) case_block)) cases @@ -585,10 +585,10 @@ let format_ctx match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> Format.fprintf fmt "%a@\n@\n" format_struct_decl - (s, snd (StructName.Map.find s ctx.ctx_structs)) + (s, StructName.Map.find s ctx.ctx_structs) | Scopelang.Dependency.TVertex.Enum e -> Format.fprintf fmt "%a@\n@\n" format_enum_decl - (e, snd (EnumName.Map.find e ctx.ctx_enums))) + (e, EnumName.Map.find e ctx.ctx_enums)) (type_ordering @ scope_structs) let format_program diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index d57c9661..dc9443fc 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -41,7 +41,7 @@ let rec locations_used (e : 'm expr) : LocationSet.t = type 'm rule = | Definition of location Mark.pos * typ * Desugared.Ast.io * 'm expr | Assertion of 'm expr - | Call of (path * ScopeName.t) * SubScopeName.t * 'm mark + | Call of ScopeName.t * SubScopeName.t * 'm mark type 'm scope_decl = { scope_decl_name : ScopeName.t; diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index 73a8df7b..3628946f 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -34,7 +34,7 @@ val locations_used : 'm expr -> LocationSet.t type 'm rule = | Definition of location Mark.pos * typ * Desugared.Ast.io * 'm expr | Assertion of 'm expr - | Call of (path * ScopeName.t) * SubScopeName.t * 'm mark + | Call of ScopeName.t * SubScopeName.t * 'm mark type 'm scope_decl = { scope_decl_name : ScopeName.t; diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index 6aeae9b7..0895b8c1 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -82,9 +82,9 @@ let rec expr_used_defs e = e VMap.empty in match e with - | ELocation (ToplevelVar { path = []; name = v, pos }), _ -> + | ELocation (ToplevelVar { name = v, pos }), _ -> VMap.singleton (Topdef v) pos - | (EScopeCall { path = []; scope; _ }, m) as e -> + | (EScopeCall { scope; _ }, m) as e -> VMap.add (Scope scope) (Expr.mark_pos m) (recurse_subterms e) | EAbs { binder; _ }, _ -> let _, body = Bindlib.unmbind binder in @@ -96,9 +96,10 @@ let rule_used_defs = function (* TODO: maybe this info could be passed on from previous passes without walking through all exprs again *) expr_used_defs e - | Ast.Call ((_ :: _path, _), _, _) -> VMap.empty - | Ast.Call (([], subscope), subindex, _) -> - VMap.singleton (Scope subscope) (Mark.get (SubScopeName.get_info subindex)) + | Ast.Call (subscope, subindex, _) -> + if ScopeName.path subscope = [] then + VMap.singleton (Scope subscope) (Mark.get (SubScopeName.get_info subindex)) + else VMap.empty let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = let g = SDependencies.empty in @@ -272,7 +273,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t let g = TDependencies.empty in let g = StructName.Map.fold - (fun s (path, fields) g -> + (fun s fields g -> StructField.Map.fold (fun _ typ g -> let def = TVertex.Struct s in @@ -282,9 +283,9 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t (fun used g -> if TVertex.equal used def then Message.raise_spanned_error (Mark.get typ) - "The type %a%a is defined using itself, which is forbidden \ + "The type %a is defined using itself, which is forbidden \ since Catala does not provide recursive types" - Print.path path TVertex.format used + TVertex.format used else let edge = TDependencies.E.create used (Mark.get typ) def in TDependencies.add_edge_e g edge) @@ -294,7 +295,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t in let g = EnumName.Map.fold - (fun e (path, cases) g -> + (fun e cases g -> EnumConstructor.Map.fold (fun _ typ g -> let def = TVertex.Enum e in @@ -304,9 +305,9 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t (fun used g -> if TVertex.equal used def then Message.raise_spanned_error (Mark.get typ) - "The type %a%a is defined using itself, which is forbidden \ + "The type %a is defined using itself, which is forbidden \ since Catala does not provide recursive types" - Print.path path TVertex.format used + TVertex.format used else let edge = TDependencies.E.create used (Mark.get typ) def in TDependencies.add_edge_e g edge) diff --git a/compiler/scopelang/dependency.mli b/compiler/scopelang/dependency.mli index 970b7e2c..e37fd7ac 100644 --- a/compiler/scopelang/dependency.mli +++ b/compiler/scopelang/dependency.mli @@ -39,7 +39,7 @@ module TVertex : sig type t = Struct of StructName.t | Enum of EnumName.t val format : Format.formatter -> t -> unit - val get_info : t -> StructName.info + val get_info : t -> Uid.MarkedString.info include Graph.Sig.COMPARABLE with type t := t end diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index b45f4a88..4c4fda30 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -33,15 +33,6 @@ type ctx = { modules : ctx ModuleName.Map.t; } -let rec module_ctx ctx = function - | [] -> ctx - | (modname, mpos) :: path -> ( - match ModuleName.Map.find_opt modname ctx.modules with - | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format - modname - | Some ctx -> module_ctx ctx path) - let tag_with_log_entry (e : untyped Ast.expr boxed) (l : log_entry) @@ -66,16 +57,16 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = ctx (Array.to_list vars) (Array.to_list new_vars) in Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m - | ELocation (SubScopeVar { path; scope; alias; var }) -> + | ELocation (SubScopeVar { scope; alias; var }) -> (* When referring to a subscope variable in an expression, we are referring to the output, hence we take the last state. *) - let ctx = module_ctx ctx path in + let ctx = List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.modules) ctx (ScopeName.path scope) in let var = match ScopeVar.Map.find (Mark.remove var) ctx.scope_var_mapping with | WholeVar new_s_var -> Mark.copy var new_s_var | States states -> Mark.copy var (snd (List.hd (List.rev states))) in - Expr.elocation (SubScopeVar { path; scope; alias; var }) m + Expr.elocation (SubScopeVar { scope; alias; var }) m | ELocation (DesugaredScopeVar { name; state = None }) -> Expr.elocation (ScopelangScopeVar @@ -107,7 +98,7 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = one possible matching structure *) Message.raise_spanned_error (Expr.mark_pos m) "Ambiguous structure field access" - | EDStructAccess { e; field; path = _; name_opt = Some name } -> + | EDStructAccess { e; field; name_opt = Some name } -> let e' = translate_expr ctx e in let field = try @@ -121,8 +112,8 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = field StructName.format name in Expr.estructaccess ~e:e' ~field ~name m - | EScopeCall { path; scope; args } -> - Expr.escopecall ~path ~scope + | EScopeCall { scope; args } -> + Expr.escopecall ~scope ~args: (ScopeVar.Map.fold (fun v e args' -> @@ -624,13 +615,12 @@ let translate_rule (D.ScopeDef.Map.find def_key exc_graphs) ~is_cond ~is_subscope_var:true in - let subscop_path, subscop_real_name = + let subscop_real_name = SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in Ast.Definition ( ( SubScopeVar { - path = subscop_path; scope = subscop_real_name; alias = sub_scope_index, var_pos; var = diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index ec25575f..e8bf316e 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -22,9 +22,9 @@ let struc ctx (fmt : Format.formatter) (name : StructName.t) - ((path, fields) : path * typ StructField.Map.t) : unit = - Format.fprintf fmt "%a %a%a %a %a@\n@[ %a@]@\n%a" Print.keyword - "struct" Print.path path StructName.format name Print.punctuation "=" + (fields : typ StructField.Map.t) : unit = + Format.fprintf fmt "%a %a %a %a@\n@[ %a@]@\n%a" Print.keyword + "struct" StructName.format name Print.punctuation "=" Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") @@ -38,9 +38,9 @@ let enum ctx (fmt : Format.formatter) (name : EnumName.t) - ((path, cases) : path * typ EnumConstructor.Map.t) : unit = - Format.fprintf fmt "%a %a%a %a @\n@[ %a@]" Print.keyword "enum" - Print.path path EnumName.format name Print.punctuation "=" + (cases : typ EnumConstructor.Map.t) : unit = + Format.fprintf fmt "%a %a %a @\n@[ %a@]" Print.keyword "enum" + EnumName.format name Print.punctuation "=" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (field_name, typ) -> @@ -93,9 +93,9 @@ let scope ?debug ctx fmt (name, (decl, _pos)) = | Assertion e -> Format.fprintf fmt "%a %a" Print.keyword "assert" (Print.expr ?debug ()) e - | Call ((scope_path, scope_name), subscope_name, _) -> - Format.fprintf fmt "%a %a%a%a%a%a" Print.keyword "call" Print.path - scope_path ScopeName.format scope_name Print.punctuation "[" + | Call (scope_name, subscope_name, _) -> + Format.fprintf fmt "%a %a%a%a%a" Print.keyword "call" + ScopeName.format scope_name Print.punctuation "[" SubScopeName.format subscope_name Print.punctuation "]")) decl.scope_decl_rules diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 76674d3f..43f30884 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -22,16 +22,13 @@ open Catala_utils module Runtime = Runtime_ocaml.Runtime -module ModuleName = String -(* TODO: should probably be turned into an Uid once we implement module import - directives; that will incur an additional resolution work on all paths - though *) +module ModuleName = Uid.Module -module ScopeName = Uid.Gen () -module TopdefName = Uid.Gen () -module StructName = Uid.Gen () +module ScopeName = Uid.Gen_qualified () +module TopdefName = Uid.Gen_qualified () +module StructName = Uid.Gen_qualified () module StructField = Uid.Gen () -module EnumName = Uid.Gen () +module EnumName = Uid.Gen_qualified () module EnumConstructor = Uid.Gen () (** Only used by surface *) @@ -348,8 +345,6 @@ type lit = | LDate of date | LDuration of duration -type path = ModuleName.t Mark.pos list - (** External references are resolved to strings that point to functions or constants in the end, but we need to keep different references for typing *) type external_ref = @@ -368,14 +363,12 @@ type 'a glocation = } -> < scopeVarSimpl : yes ; .. > glocation | SubScopeVar : { - path : path; scope : ScopeName.t; alias : SubScopeName.t Mark.pos; var : ScopeVar.t Mark.pos; } -> < explicitScopes : yes ; .. > glocation | ToplevelVar : { - path : path; name : TopdefName.t Mark.pos; } -> < explicitScopes : yes ; .. > glocation @@ -456,13 +449,11 @@ and ('a, 'b, 'm) base_gexpr = (* Early stages *) | ELocation : 'b glocation -> ('a, (< .. > as 'b), 'm) base_gexpr | EScopeCall : { - path : path; scope : ScopeName.t; args : ('a, 'm) gexpr ScopeVar.Map.t; } -> ('a, < explicitScopes : yes ; .. >, 'm) base_gexpr | EDStructAccess : { - path : path; name_opt : StructName.t option; e : ('a, 'm) gexpr; field : Ident.t; @@ -478,7 +469,6 @@ and ('a, 'b, 'm) base_gexpr = (** Resolved struct/enums, after [desugared] *) (* Lambda-like *) | EExternal : { - path : path; name : external_ref Mark.pos; } -> ('a, < explicitScopes : no ; .. >, 't) base_gexpr @@ -594,8 +584,8 @@ type 'e code_item_list = | Nil | Cons of 'e code_item * ('e, 'e code_item_list) binder -type struct_ctx = (path * typ StructField.Map.t) StructName.Map.t -type enum_ctx = (path * typ EnumConstructor.Map.t) EnumName.Map.t +type struct_ctx = typ StructField.Map.t StructName.Map.t +type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t type scope_info = { in_struct_name : StructName.t; diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index efe832bd..920f4fa7 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -110,8 +110,8 @@ let subst binder vars = let evar v mark = Mark.add mark (Bindlib.box_var v) -let eexternal ~path ~name mark = - Mark.add mark (Bindlib.box (EExternal { path; name })) +let eexternal ~name mark = + Mark.add mark (Bindlib.box (EExternal { name })) let etuple args = Box.appn args @@ fun args -> ETuple args @@ -155,8 +155,8 @@ let estruct ~name ~(fields : ('a, 't) boxed_gexpr StructField.Map.t) mark = (fun fields -> EStruct { name; fields }) (Box.lift_struct (StructField.Map.map Box.lift fields)) -let edstructaccess ~path ~name_opt ~field ~e = - Box.app1 e @@ fun e -> EDStructAccess { path; name_opt; field; e } +let edstructaccess ~name_opt ~field ~e = + Box.app1 e @@ fun e -> EDStructAccess { name_opt; field; e } let estructaccess ~name ~field ~e = Box.app1 e @@ fun e -> EStructAccess { name; field; e } @@ -170,10 +170,10 @@ let ematch ~name ~e ~cases mark = (Box.lift e) (Box.lift_enum (EnumConstructor.Map.map Box.lift cases)) -let escopecall ~path ~scope ~args mark = +let escopecall ~scope ~args mark = Mark.add mark @@ Bindlib.box_apply - (fun args -> EScopeCall { path; scope; args }) + (fun args -> EScopeCall { scope; args }) (Box.lift_scope_vars (ScopeVar.Map.map Box.lift args)) (* - Manipulation of marks - *) @@ -253,7 +253,7 @@ let maybe_ty (type m) ?(typ = TAny) (m : m mark) : typ = (* - Predefined types (option) - *) -let option_enum = EnumName.fresh ("eoption", Pos.no_pos) +let option_enum = EnumName.fresh [] ("eoption", Pos.no_pos) let none_constr = EnumConstructor.fresh ("ENone", Pos.no_pos) let some_constr = EnumConstructor.fresh ("ESome", Pos.no_pos) @@ -275,7 +275,7 @@ let map | EOp { op; tys } -> eop op tys m | EArray args -> earray (List.map f args) m | EVar v -> evar (Var.translate v) m - | EExternal { path; name } -> eexternal ~path ~name m + | EExternal { name } -> eexternal ~name m | EAbs { binder; tys } -> let vars, body = Bindlib.unmbind binder in let body = f body in @@ -297,15 +297,15 @@ let map | EStruct { name; fields } -> let fields = StructField.Map.map f fields in estruct ~name ~fields m - | EDStructAccess { path; name_opt; field; e } -> - edstructaccess ~path ~name_opt ~field ~e:(f e) m + | EDStructAccess { name_opt; field; e } -> + edstructaccess ~name_opt ~field ~e:(f e) m | EStructAccess { name; field; e } -> estructaccess ~name ~field ~e:(f e) m | EMatch { name; e; cases } -> let cases = EnumConstructor.Map.map f cases in ematch ~name ~e:(f e) ~cases m - | EScopeCall { path; scope; args } -> + | EScopeCall { scope; args } -> let args = ScopeVar.Map.map f args in - escopecall ~path ~scope ~args m + escopecall ~scope ~args m | ECustom { obj; targs; tret } -> ecustom obj targs tret m let rec map_top_down ~f e = map ~f:(map_top_down ~f) (f e) @@ -372,7 +372,7 @@ let map_gather let acc, args = lfoldmap args in acc, earray args m | EVar v -> acc, evar (Var.translate v) m - | EExternal { path; name } -> acc, eexternal ~path ~name m + | EExternal { name } -> acc, eexternal ~name m | EAbs { binder; tys } -> let vars, body = Bindlib.unmbind binder in let acc, body = f body in @@ -420,9 +420,9 @@ let map_gather (acc, StructField.Map.empty) in acc, estruct ~name ~fields m - | EDStructAccess { path; name_opt; field; e } -> + | EDStructAccess { name_opt; field; e } -> let acc, e = f e in - acc, edstructaccess ~path ~name_opt ~field ~e m + acc, edstructaccess ~name_opt ~field ~e m | EStructAccess { name; field; e } -> let acc, e = f e in acc, estructaccess ~name ~field ~e m @@ -437,7 +437,7 @@ let map_gather (acc, EnumConstructor.Map.empty) in acc, ematch ~name ~e ~cases m - | EScopeCall { path; scope; args } -> + | EScopeCall { scope; args } -> let acc, args = ScopeVar.Map.fold (fun var e (acc, args) -> @@ -445,7 +445,7 @@ let map_gather join acc acc1, ScopeVar.Map.add var e args) args (acc, ScopeVar.Map.empty) in - acc, escopecall ~path ~scope ~args m + acc, escopecall ~scope ~args m | ECustom { obj; targs; tret } -> acc, ecustom obj targs tret m (* - *) @@ -518,8 +518,6 @@ let compare_lit (l1 : lit) (l2 : lit) = | LDuration _, _ -> . | _, LDuration _ -> . -let compare_path = List.compare (Mark.compare ModuleName.compare) - let compare_location (type a) (x : a glocation Mark.pos) @@ -542,9 +540,9 @@ let compare_location SubScopeVar { alias = ysubindex, _; var = ysubvar, _; _ } ) -> let c = SubScopeName.compare xsubindex ysubindex in if c = 0 then ScopeVar.compare xsubvar ysubvar else c - | ( ToplevelVar { path = px; name = vx, _ }, - ToplevelVar { path = py; name = vy, _ } ) -> ( - match compare_path px py with 0 -> TopdefName.compare vx vy | n -> n) + | ( ToplevelVar { name = vx, _ }, + ToplevelVar { name = vy, _ } ) -> + TopdefName.compare vx vy | DesugaredScopeVar _, _ -> -1 | _, DesugaredScopeVar _ -> 1 | ScopelangScopeVar _, _ -> -1 @@ -554,7 +552,6 @@ let compare_location | ToplevelVar _, _ -> . | _, ToplevelVar _ -> . -let equal_path = List.equal (Mark.equal ModuleName.equal) let equal_location a b = compare_location a b = 0 let equal_except ex1 ex2 = ex1 = ex2 let compare_except ex1 ex2 = Stdlib.compare ex1 ex2 @@ -583,8 +580,8 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = fun e1 e2 -> match Mark.remove e1, Mark.remove e2 with | EVar v1, EVar v2 -> Bindlib.eq_vars v1 v2 - | EExternal { path = p1; name = n1 }, EExternal { path = p2; name = n2 } -> - Mark.equal equal_external_ref n1 n2 && equal_path p1 p2 + | EExternal { name = n1 }, EExternal { name = n2 } -> + Mark.equal equal_external_ref n1 n2 | ETuple es1, ETuple es2 -> equal_list es1 es2 | ( ETupleAccess { e = e1; index = id1; size = s1 }, ETupleAccess { e = e2; index = id2; size = s2 } ) -> @@ -615,10 +612,9 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = | ( EStruct { name = s1; fields = fields1 }, EStruct { name = s2; fields = fields2 } ) -> StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2 - | ( EDStructAccess { e = e1; field = f1; name_opt = s1; path = p1 }, - EDStructAccess { e = e2; field = f2; name_opt = s2; path = p2 } ) -> + | ( EDStructAccess { e = e1; field = f1; name_opt = s1 }, + EDStructAccess { e = e2; field = f2; name_opt = s2 } ) -> Option.equal StructName.equal s1 s2 - && equal_path p1 p2 && Ident.equal f1 f2 && equal e1 e2 | ( EStructAccess { e = e1; field = f1; name = s1 }, @@ -632,10 +628,9 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = EnumName.equal n1 n2 && equal e1 e2 && EnumConstructor.Map.equal equal cases1 cases2 - | ( EScopeCall { path = p1; scope = s1; args = fields1 }, - EScopeCall { path = p2; scope = s2; args = fields2 } ) -> + | ( EScopeCall { scope = s1; args = fields1 }, + EScopeCall { scope = s2; args = fields2 } ) -> ScopeName.equal s1 s2 - && equal_path p1 p2 && ScopeVar.Map.equal equal fields1 fields2 | ( ECustom { obj = obj1; targs = targs1; tret = tret1 }, ECustom { obj = obj2; targs = targs2; tret = tret2 } ) -> @@ -667,8 +662,8 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = List.compare compare a1 a2 | EVar v1, EVar v2 -> Bindlib.compare_vars v1 v2 - | EExternal { path = p1; name = n1 }, EExternal { path = p2; name = n2 } -> - compare_path p1 p2 @@< fun () -> Mark.compare compare_external_ref n1 n2 + | EExternal { name = n1 }, EExternal { name = n2 } -> + Mark.compare compare_external_ref n1 n2 | EAbs {binder=binder1; tys=typs1}, EAbs {binder=binder2; tys=typs2} -> List.compare Type.compare typs1 typs2 @@< fun () -> @@ -685,10 +680,9 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = EStruct {name=name2; fields=field_map2 } -> StructName.compare name1 name2 @@< fun () -> StructField.Map.compare compare field_map1 field_map2 - | EDStructAccess {e=e1; field=field_name1; name_opt=struct_name1; path=p1}, - EDStructAccess {e=e2; field=field_name2; name_opt=struct_name2; path=p2} -> + | EDStructAccess {e=e1; field=field_name1; name_opt=struct_name1}, + EDStructAccess {e=e2; field=field_name2; name_opt=struct_name2} -> compare e1 e2 @@< fun () -> - compare_path p1 p2 @@< fun () -> Ident.compare field_name1 field_name2 @@< fun () -> Option.compare StructName.compare struct_name1 struct_name2 | EStructAccess {e=e1; field=field_name1; name=struct_name1 }, @@ -701,9 +695,8 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = EnumName.compare name1 name2 @@< fun () -> compare e1 e2 @@< fun () -> EnumConstructor.Map.compare compare emap1 emap2 - | EScopeCall {path = p1; scope=name1; args=field_map1}, - EScopeCall {path = p2; scope=name2; args=field_map2} -> - compare_path p1 p2 @@< fun () -> + | EScopeCall {scope=name1; args=field_map1}, + EScopeCall {scope=name2; args=field_map2} -> ScopeName.compare name1 name2 @@< fun () -> ScopeVar.Map.compare compare field_map1 field_map2 | ETuple es1, ETuple es2 -> diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 01f3bc6d..b485d396 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -38,7 +38,6 @@ val rebox : ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr val evar : ('a, 'm) gexpr Var.t -> 'm mark -> ('a, 'm) boxed_gexpr val eexternal : - path:path -> name:external_ref Mark.pos -> 'm mark -> (< explicitScopes : no ; .. >, 'm) boxed_gexpr @@ -119,7 +118,6 @@ val estruct : ('a any, 'm) boxed_gexpr val edstructaccess : - path:path -> name_opt:StructName.t option -> field:Ident.t -> e:('a, 'm) boxed_gexpr -> @@ -148,7 +146,6 @@ val ematch : ('a any, 'm) boxed_gexpr val escopecall : - path:path -> scope:ScopeName.t -> args:('a, 'm) boxed_gexpr ScopeVar.Map.t -> 'm mark -> @@ -389,8 +386,6 @@ val format : Format.formatter -> ('a, 'm) gexpr -> unit val equal_lit : lit -> lit -> bool val compare_lit : lit -> lit -> int -val equal_path : path -> path -> bool -val compare_path : path -> path -> int val equal_location : 'a glocation Mark.pos -> 'a glocation Mark.pos -> bool val compare_location : 'a glocation Mark.pos -> 'a glocation Mark.pos -> int val equal_except : except -> except -> bool diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 0d078eaf..1b53676e 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -448,7 +448,6 @@ let rec runtime_to_val : m ) | TStruct name -> StructName.Map.find name ctx.ctx_structs - |> snd |> StructField.Map.to_seq |> Seq.map2 (fun o (fld, ty) -> fld, runtime_to_val eval_expr ctx m ty o) @@ -459,7 +458,7 @@ let rec runtime_to_val : (* we only use non-constant constructors of arity 1, which allows us to always use the tag directly (ordered as declared in the constr map), and the field 0 *) - let _path, cons_map = EnumName.Map.find name ctx.ctx_enums in + let cons_map = EnumName.Map.find name ctx.ctx_enums in let cons, ty = List.nth (EnumConstructor.Map.bindings cons_map) @@ -497,7 +496,7 @@ and val_to_runtime : List.map2 (val_to_runtime eval_expr ctx) ts es |> Array.of_list |> Obj.repr | TStruct name1, EStruct { name; fields } -> assert (StructName.equal name name1); - let _path, fld_tys = StructName.Map.find name ctx.ctx_structs in + let fld_tys = StructName.Map.find name ctx.ctx_structs in Seq.map2 (fun (_, ty) (_, v) -> val_to_runtime eval_expr ctx ty v) (StructField.Map.to_seq fld_tys) @@ -506,7 +505,7 @@ and val_to_runtime : |> Obj.repr | TEnum name1, EInj { name; cons; e } -> assert (EnumName.equal name name1); - let _path, cons_map = EnumName.Map.find name ctx.ctx_enums in + let cons_map = EnumName.Map.find name ctx.ctx_enums in let rec find_tag n = function | [] -> assert false | (c, ty) :: _ when EnumConstructor.equal c cons -> n, ty @@ -549,7 +548,11 @@ let rec evaluate_expr : Message.raise_spanned_error pos "free variable found at evaluation (should not happen if term was \ well-typed)" - | EExternal { path; name } -> + | EExternal { name } -> + let path = match Mark.remove name with + | External_value td -> TopdefName.path td + | External_scope s -> ScopeName.path s + in let ty = try let ctx = Program.module_ctx ctx path in @@ -563,11 +566,11 @@ let rec evaluate_expr : pos ) with TopdefName.Map.Not_found _ | ScopeName.Map.Not_found _ -> Message.raise_spanned_error pos - "Reference to %a%a could not be resolved" Print.path path + "Reference to %a could not be resolved" Print.external_ref name in let runtime_path = - ( List.map Mark.remove path, + ( List.map ModuleName.to_string path, match Mark.remove name with | External_value name -> Mark.remove (TopdefName.get_info name) | External_scope name -> Mark.remove (ScopeName.get_info name) ) @@ -814,7 +817,7 @@ let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list the types of the scope arguments. For [context] arguments, we can provide an empty thunked term. But for [input] arguments of another type, we cannot provide anything so we have to fail. *) - let _path, taus = StructName.Map.find s_in ctx.ctx_structs in + let taus = StructName.Map.find s_in ctx.ctx_structs in let application_term = StructField.Map.map (fun ty -> @@ -864,7 +867,7 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list the types of the scope arguments. For [context] arguments, we can provide an empty thunked term. But for [input] arguments of another type, we cannot provide anything so we have to fail. *) - let _path, taus = StructName.Map.find s_in ctx.ctx_structs in + let taus = StructName.Map.find s_in ctx.ctx_structs in let application_term = StructField.Map.map (fun ty -> diff --git a/compiler/shared_ast/optimizations.ml b/compiler/shared_ast/optimizations.ml index c614dfeb..d1ee6a6a 100644 --- a/compiler/shared_ast/optimizations.ml +++ b/compiler/shared_ast/optimizations.ml @@ -347,7 +347,7 @@ let optimize_program (p : 'm program) : 'm program = let test_iota_reduction_1 () = let x = Var.make "x" in - let enumT = EnumName.fresh ("t", Pos.no_pos) in + let enumT = EnumName.fresh [] ("t", Pos.no_pos) in let consA = EnumConstructor.fresh ("A", Pos.no_pos) in let consB = EnumConstructor.fresh ("B", Pos.no_pos) in let consC = EnumConstructor.fresh ("C", Pos.no_pos) in @@ -387,7 +387,7 @@ let cases_of_list l : ('a, 't) boxed_gexpr EnumConstructor.Map.t = (Untyped { pos = Pos.no_pos }) )) let test_iota_reduction_2 () = - let enumT = EnumName.fresh ("t", Pos.no_pos) in + let enumT = EnumName.fresh [] ("t", Pos.no_pos) in let consA = EnumConstructor.fresh ("A", Pos.no_pos) in let consB = EnumConstructor.fresh ("B", Pos.no_pos) in let consC = EnumConstructor.fresh ("C", Pos.no_pos) in diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index 903dbf4f..9bc428a7 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -86,8 +86,7 @@ let location (type a) (fmt : Format.formatter) (l : a glocation) : unit = | SubScopeVar { alias = subindex; var = subvar; _ } -> Format.fprintf fmt "%a.%a" SubScopeName.format (Mark.remove subindex) ScopeVar.format (Mark.remove subvar) - | ToplevelVar { path = p; name } -> - path fmt p; + | ToplevelVar { name } -> TopdefName.format fmt (Mark.remove name) let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit = @@ -131,12 +130,11 @@ let rec typ_gen match ctx with | None -> StructName.format fmt s | Some ctx -> - let p, fields = StructName.Map.find s ctx.ctx_structs in + let fields = StructName.Map.find s ctx.ctx_structs in if StructField.Map.is_empty fields then ( - path fmt p; StructName.format fmt s) else - Format.fprintf fmt "@[%a%a %a@,%a@;<0 -2>%a@]" path p + Format.fprintf fmt "@[%a %a@,%a@;<0 -2>%a@]" StructName.format s (pp_color_string (List.hd colors)) "{" @@ -156,8 +154,8 @@ let rec typ_gen match ctx with | None -> Format.fprintf fmt "@[%a@]" EnumName.format e | Some ctx -> - let p, def = EnumName.Map.find e ctx.ctx_enums in - Format.fprintf fmt "@[%a%a%a%a%a@]" path p EnumName.format e + let def = EnumName.Map.find e ctx.ctx_enums in + Format.fprintf fmt "@[%a%a%a%a@]" EnumName.format e punctuation "[" (EnumConstructor.Map.format_bindings ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " punctuation "|") @@ -519,8 +517,7 @@ module ExprGen (C : EXPR_PARAM) = struct else match Mark.remove e with | EVar v -> var fmt v - | EExternal { path = p; name } -> - path fmt p; + | EExternal { name } -> external_ref fmt name | ETuple es -> Format.fprintf fmt "@[%a%a%a@]" @@ -718,9 +715,8 @@ module ExprGen (C : EXPR_PARAM) = struct Format.fprintf fmt "@[%a %t@ %a@ %a@]" punctuation "|" pp_cons_name punctuation "→" (rhs exprc) e)) cases - | EScopeCall { path = scope_path; scope; args } -> + | EScopeCall { scope; args } -> Format.pp_open_hovbox fmt 2; - path fmt scope_path; ScopeName.format fmt scope; Format.pp_print_space fmt (); keyword fmt "of"; @@ -862,8 +858,8 @@ let enum decl_ctx fmt (pp_name : Format.formatter -> unit) - ((p, c) : path * typ EnumConstructor.Map.t) = - Format.fprintf fmt "@[%a %a%t %a@ %a@]" keyword "type" path p pp_name + (c : typ EnumConstructor.Map.t) = + Format.fprintf fmt "@[%a %t %a@ %a@]" keyword "type" pp_name punctuation "=" (EnumConstructor.Map.format_bindings ~pp_sep:(fun _ _ -> ()) @@ -879,9 +875,9 @@ let struct_ decl_ctx fmt (pp_name : Format.formatter -> unit) - ((p, c) : path * typ StructField.Map.t) = - Format.fprintf fmt "@[@[@[%a %a%t %a@;%a@]@;%a@]%a@]@;" keyword - "type" path p pp_name punctuation "=" punctuation "{" + (c : typ StructField.Map.t) = + Format.fprintf fmt "@[@[@[%a %t %a@;%a@]@;%a@]%a@]@;" keyword + "type" pp_name punctuation "=" punctuation "{" (StructField.Map.format_bindings ~pp_sep:(fun _ _ -> ()) (fun fmt pp_n ty -> diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 20fce92e..728c73d7 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -15,7 +15,6 @@ License for the specific language governing permissions and limitations under the License. *) -open Catala_utils open Definitions let map_exprs ~f ~varf { code_items; decl_ctx } = @@ -39,14 +38,9 @@ let empty_ctx = ctx_modules = ModuleName.Map.empty; } -let rec module_ctx ctx = function - | [] -> ctx - | (modname, mpos) :: path -> ( - match ModuleName.Map.find_opt modname ctx.ctx_modules with - | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format - modname - | Some ctx -> module_ctx ctx path) +let module_ctx ctx path = + List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.ctx_modules) + ctx path let get_scope_body { code_items; _ } scope = match diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index dda3553e..58a8f55b 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -22,9 +22,9 @@ open Definitions val empty_ctx : decl_ctx -val module_ctx : decl_ctx -> ModuleName.t Mark.pos list -> decl_ctx +val module_ctx : decl_ctx -> Uid.Path.t -> decl_ctx (** Follows a path to get the corresponding context for type and value - declarations. Errors out if the module is not found *) + declarations. *) (** {2 Transformations} *) diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 9c20358a..d1770225 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -130,10 +130,8 @@ let rec format_typ (pp_color_string (List.hd colors)) ")" | TStruct s -> - Print.path fmt (fst (A.StructName.Map.find s ctx.A.ctx_structs)); A.StructName.format fmt s | TEnum e -> - Print.path fmt (fst (A.EnumName.Map.find e ctx.A.ctx_enums)); A.EnumName.format fmt e | TOption t -> Format.fprintf fmt "@[option %a@]" @@ -325,11 +323,11 @@ module Env = struct { structs = A.StructName.Map.map - (fun (_path, ty) -> A.StructField.Map.map ast_to_typ ty) + (fun ty -> A.StructField.Map.map ast_to_typ ty) decl_ctx.ctx_structs; enums = A.EnumName.Map.map - (fun (_path, ty) -> A.EnumConstructor.Map.map ast_to_typ ty) + (fun ty -> A.EnumConstructor.Map.map ast_to_typ ty) decl_ctx.ctx_enums; vars = Var.Map.empty; scope_vars = A.ScopeVar.Map.empty; @@ -347,14 +345,7 @@ module Env = struct A.ScopeVar.Map.find_opt var vmap) let rec module_env path env = - match path with - | [] -> env - | (modname, mpos) :: path -> ( - match A.ModuleName.Map.find_opt modname env.modules with - | None -> - Message.raise_spanned_error mpos "Module %a not found" - A.ModuleName.format modname - | Some env -> module_env path env) + List.fold_left (fun env m -> A.ModuleName.Map.find m env.modules) env path let add v tau t = { t with vars = Var.Map.add v tau t.vars } let add_var v typ t = add v (ast_to_typ typ) t @@ -435,11 +426,11 @@ and typecheck_expr_top_down : match loc with | DesugaredScopeVar { name; _ } | ScopelangScopeVar { name } -> Env.get_scope_var env (Mark.remove name) - | SubScopeVar { path; scope; var; _ } -> - let env = Env.module_env path env in + | SubScopeVar { scope; var; _ } -> + let env = Env.module_env (A.ScopeName.path scope) env in Env.get_subscope_out_var env scope (Mark.remove var) - | ToplevelVar { path; name } -> - let env = Env.module_env path env in + | ToplevelVar { name } -> + let env = Env.module_env (A.TopdefName.path (Mark.remove name)) env in Env.get_toplevel_var env (Mark.remove name) in let ty = @@ -452,7 +443,7 @@ and typecheck_expr_top_down : Expr.elocation loc (mark_with_tau_and_unify (ast_to_typ ty)) | A.EStruct { name; fields } -> let mark = ty_mark (TStruct name) in - let _path, str_ast = A.StructName.Map.find name ctx.A.ctx_structs in + let str_ast = A.StructName.Map.find name ctx.A.ctx_structs in let str = A.StructName.Map.find name env.structs in let _check_fields : unit = let missing_fields, extra_fields = @@ -493,7 +484,7 @@ and typecheck_expr_top_down : fields in Expr.estruct ~name ~fields mark - | A.EDStructAccess { e = e_struct; path = _; name_opt; field } -> + | A.EDStructAccess { e = e_struct; name_opt; field } -> let t_struct = match name_opt with | Some name -> TStruct name @@ -514,7 +505,6 @@ and typecheck_expr_top_down : "This is not a structure, cannot access field %s (%a)" field (format_typ ctx) (ty e_struct') in - let path, _ = A.StructName.Map.find name ctx.ctx_structs in let fld_ty = let str = try A.StructName.Map.find name env.structs @@ -549,7 +539,7 @@ and typecheck_expr_top_down : A.StructField.Map.find field str in let mark = mark_with_tau_and_unify fld_ty in - Expr.edstructaccess ~e:e_struct' ~path ~name_opt:(Some name) ~field mark + Expr.edstructaccess ~e:e_struct' ~name_opt:(Some name) ~field mark | A.EStructAccess { e = e_struct; name; field } -> let fld_ty = let str = @@ -628,7 +618,7 @@ and typecheck_expr_top_down : in Expr.ematch ~e:e1' ~name ~cases mark | A.EMatch { e = e1; name; cases } -> - let _path, cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in + let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in let mark = mark_with_tau_and_unify t_ret in let e1' = @@ -647,7 +637,8 @@ and typecheck_expr_top_down : cases in Expr.ematch ~e:e1' ~name ~cases mark - | A.EScopeCall { path; scope; args } -> + | A.EScopeCall { scope; args } -> + let path = A.ScopeName.path scope in let scope_out_struct = let ctx = Program.module_ctx ctx path in (A.ScopeName.Map.find scope ctx.ctx_scopes).out_struct_name @@ -664,7 +655,7 @@ and typecheck_expr_top_down : (ast_to_typ (A.ScopeVar.Map.find name vars))) args in - Expr.escopecall ~path ~scope ~args:args' mark + Expr.escopecall ~scope ~args:args' mark | A.ERaise ex -> Expr.eraise ex context_mark | A.ECatch { body; exn; handler } -> let body' = typecheck_expr_top_down ~leave_unresolved ctx env tau body in @@ -681,14 +672,18 @@ and typecheck_expr_top_down : "Variable %s not found in the current context" (Bindlib.name_of v) in Expr.evar (Var.translate v) (mark_with_tau_and_unify tau') - | A.EExternal { path; name } -> + | A.EExternal { name } -> + let path = match Mark.remove name with + | External_value td -> A.TopdefName.path td + | External_scope s -> A.ScopeName.path s + in let ctx = Program.module_ctx ctx path in let ty = let not_found pr x = Message.raise_spanned_error pos_e - "Could not resolve the reference to %a%a.@ Make sure the \ + "Could not resolve the reference to %a.@ Make sure the \ corresponding module was properly loaded?" - Print.path path pr x + pr x in match Mark.remove name with | A.External_value name -> ( @@ -705,7 +700,7 @@ and typecheck_expr_top_down : pos_e ) with A.ScopeName.Map.Not_found _ -> not_found A.ScopeName.format name) in - Expr.eexternal ~path ~name (mark_with_tau_and_unify ty) + Expr.eexternal ~name (mark_with_tau_and_unify ty) | A.ELit lit -> Expr.elit lit (ty_mark (lit_type lit)) | A.ETuple es -> let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in @@ -1031,9 +1026,8 @@ let program ~leave_unresolved prg = prg.decl_ctx with ctx_structs = A.StructName.Map.mapi - (fun s_name (path, fields) -> - ( path, - A.StructField.Map.mapi + (fun s_name (fields) -> + ( A.StructField.Map.mapi (fun f_name (t : A.typ) -> match Mark.remove t with | TAny -> @@ -1045,9 +1039,8 @@ let program ~leave_unresolved prg = prg.decl_ctx.ctx_structs; ctx_enums = A.EnumName.Map.mapi - (fun e_name (path, cons) -> - ( path, - A.EnumConstructor.Map.mapi + (fun e_name (cons) -> + ( A.EnumConstructor.Map.mapi (fun cons_name (t : A.typ) -> match Mark.remove t with | TAny -> diff --git a/compiler/shared_ast/typing.mli b/compiler/shared_ast/typing.mli index 7de65aad..836bd598 100644 --- a/compiler/shared_ast/typing.mli +++ b/compiler/shared_ast/typing.mli @@ -17,6 +17,7 @@ (** Typing for the default calculus. Because of the error terms, we perform type inference using the classical W algorithm with union-find unification. *) +open Catala_utils open Definitions module Env : sig @@ -28,7 +29,7 @@ module Env : sig val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t val add_scope : ScopeName.t -> vars:typ ScopeVar.Map.t -> 'e t -> 'e t val add_module : ModuleName.t -> module_env:'e t -> 'e t -> 'e t - val module_env : path -> 'e t -> 'e t + val module_env : Uid.Path.t -> 'e t -> 'e t val open_scope : ScopeName.t -> 'e t -> 'e t end diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index a3ab4f5c..1aae7076 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -162,7 +162,7 @@ let rec print_z3model_expr (ctx : context) (ty : typ) (e : Expr.expr) : string = match Mark.remove ty with | TLit ty -> print_lit ty | TStruct name -> - let _path, s = StructName.Map.find name ctx.ctx_decl.ctx_structs in + let s = StructName.Map.find name ctx.ctx_decl.ctx_structs in let get_fieldname (fn : StructField.t) : string = Mark.remove (StructField.get_info fn) in @@ -188,7 +188,7 @@ let rec print_z3model_expr (ctx : context) (ty : typ) (e : Expr.expr) : string = let fd = Expr.get_func_decl e in let fd_name = Symbol.to_string (FuncDecl.get_name fd) in - let _path, enum_ctrs = EnumName.Map.find name ctx.ctx_decl.ctx_enums in + let enum_ctrs = EnumName.Map.find name ctx.ctx_decl.ctx_enums in let case = List.find (fun (ctr, _) -> @@ -315,7 +315,7 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : match EnumName.Map.find_opt enum ctx.ctx_z3datatypes with | Some e -> ctx, e | None -> - let _path, ctrs = EnumName.Map.find enum ctx.ctx_decl.ctx_enums in + let ctrs = EnumName.Map.find enum ctx.ctx_decl.ctx_enums in let ctx, z3_ctrs = EnumConstructor.Map.fold (fun ctr ty (ctx, ctrs) -> @@ -340,7 +340,7 @@ and find_or_create_struct (ctx : context) (s : StructName.t) : | Some s -> ctx, s | None -> let s_name = Mark.remove (StructName.get_info s) in - let _path, fields = StructName.Map.find s ctx.ctx_decl.ctx_structs in + let fields = StructName.Map.find s ctx.ctx_decl.ctx_structs in let z3_fieldnames = List.map (fun f -> @@ -666,7 +666,7 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = mk_struct. The accessors of this constructor correspond to the field accesses *) let accessors = List.hd (Datatype.get_accessors z3_struct) in - let _path, fields = StructName.Map.find name ctx.ctx_decl.ctx_structs in + let fields = StructName.Map.find name ctx.ctx_decl.ctx_structs in let idx_mappings = List.combine (StructField.Map.keys fields) accessors in let _, accessor = List.find (fun (field1, _) -> StructField.equal field field1) idx_mappings @@ -681,7 +681,7 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = let ctx, z3_enum = find_or_create_enum ctx name in let ctx, z3_arg = translate_expr ctx e in let ctrs = Datatype.get_constructors z3_enum in - let _path, cons_map = EnumName.Map.find name ctx.ctx_decl.ctx_enums in + let cons_map = EnumName.Map.find name ctx.ctx_decl.ctx_enums in (* This should always succeed if the expression is well-typed in dcalc *) let idx_mappings = List.combine (EnumConstructor.Map.keys cons_map) ctrs in let _, ctr = diff --git a/tests/test_modules/good/mod_use.catala_en b/tests/test_modules/good/mod_use.catala_en index 976dca20..4d4cd751 100644 --- a/tests/test_modules/good/mod_use.catala_en +++ b/tests/test_modules/good/mod_use.catala_en @@ -12,6 +12,9 @@ scope T2: definition o1 equals Mod_def.Enum1.No definition o2 equals t1.e1 definition o3 equals t1.sr + assertion o1 = Mod_def.Enum1.No + assertion o2 = Mod_def.Enum1.Maybe + assertion o3 = $1000 ``` ```catala-test-inline