diff --git a/Makefile b/Makefile index c5058f32..aa452712 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ export # Dependencies ########################################## -EXECUTABLES = groff python3 colordiff node node npm ninja pandoc +EXECUTABLES = groff python3 node npm ninja pandoc K := $(foreach exec,$(EXECUTABLES),\ $(if $(shell which $(exec)),some string,$(warning [WARNING] No "$(exec)" executable found. \ Please install this executable for everything to work smoothly))) diff --git a/compiler/catala_utils/uid.ml b/compiler/catala_utils/uid.ml index f6d89215..3cdbaea1 100644 --- a/compiler/catala_utils/uid.ml +++ b/compiler/catala_utils/uid.ml @@ -35,7 +35,6 @@ module type Id = sig val hash : t -> int module Set : Set.S with type elt = t - module SetLabels : MoreLabels.Set.S with type elt = t and type t = Set.t module Map : Map.S with type key = t end @@ -43,7 +42,7 @@ module Make (X : Info) () : Id with type info = X.info = struct module Ordering = struct type t = { id : int; info : X.info } - let compare (x : t) (y : t) : int = compare x.id y.id + let compare (x : t) (y : t) : int = Int.compare x.id y.id let equal x y = Int.equal x.id y.id let format ppf t = X.format ppf t.info end @@ -64,8 +63,6 @@ module Make (X : Info) () : Id with type info = X.info = struct module Set = Set.Make (Ordering) module Map = Map.Make (Ordering) - module SetLabels = MoreLabels.Set.Make (Ordering) - module MapLabels = MoreLabels.Map.Make (Ordering) end module MarkedString = struct diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index deb198e2..aa479267 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -50,7 +50,6 @@ module type Id = sig val hash : t -> int module Set : Set.S with type elt = t - module SetLabels : MoreLabels.Set.S with type elt = t and type t = Set.t module Map : Map.S with type key = t end diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 14d7103d..1ed339b1 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -29,26 +29,26 @@ type scope_input_var_ctx = { scope_input_typ : naked_typ; } +type 'm scope_ref = + | Local_scope_ref of 'm Ast.expr Var.t + | External_scope_ref of path * ScopeName.t Mark.pos + type 'm scope_sig_ctx = { scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *) - scope_sig_scope_var : 'm Ast.expr Var.t; (** Var representing the scope *) - scope_sig_input_var : 'm Ast.expr Var.t; - (** Var representing the scope input inside the scope func *) + scope_sig_scope_ref : 'm scope_ref; (** Var or external representing the scope *) scope_sig_input_struct : StructName.t; (** Scope input *) scope_sig_output_struct : StructName.t; (** Scope output *) scope_sig_in_fields : scope_input_var_ctx ScopeVar.Map.t; (** Mapping between the input scope variables and the input struct fields. *) - scope_sig_out_fields : StructField.t ScopeVar.Map.t; - (** Mapping between the output scope variables and the output struct - fields. TODO: could likely be removed now that we have it in the - program ctx *) } -type 'm scope_sigs_ctx = 'm scope_sig_ctx ScopeName.Map.t +type 'm scope_sigs_ctx = { + scope_sigs: 'm scope_sig_ctx ScopeName.Map.t; + scope_sigs_modules: 'm scope_sigs_ctx ModuleName.Map.t; +} type 'm ctx = { - structs : struct_ctx; - enums : enum_ctx; + decl_ctx : decl_ctx; scope_name : ScopeName.t option; scopes_parameters : 'm scope_sigs_ctx; toplevel_vars : ('m Ast.expr Var.t * naked_typ) TopdefName.Map.t; @@ -72,6 +72,15 @@ let pos_mark_mk (type a m) (e : (a, m) gexpr) : let pos_mark_as e = pos_mark (Mark.get e) in pos_mark, pos_mark_as +let rec module_scope_sig scope_sig_ctx path scope = + match path with + | [] -> ScopeName.Map.find scope scope_sig_ctx.scope_sigs + | (modname, mpos) :: path -> + match ModuleName.Map.find_opt modname scope_sig_ctx.scope_sigs_modules with + | None -> + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname + | Some sig_ctx -> module_scope_sig sig_ctx path scope + let merge_defaults ~(is_func : bool) (caller : (dcalc, 'm) boxed_gexpr) @@ -203,7 +212,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : let m = Mark.get e in match Mark.remove e with | EMatch { e = e1; name; cases = e_cases } -> - let enum_sig = EnumName.Map.find name ctx.enums in + let path, enum_sig = EnumName.Map.find name ctx.decl_ctx.ctx_enums in let d_cases, remaining_e_cases = (* FIXME: these checks should probably be moved to a better place *) EnumConstructor.Map.fold @@ -212,9 +221,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : try EnumConstructor.Map.find constructor e_cases with Not_found -> Message.raise_spanned_error (Expr.pos e) - "The constructor %a of enum %a is missing from this pattern \ + "The constructor %a of enum %a%a is missing from this pattern \ matching" - EnumConstructor.format constructor EnumName.format name + EnumConstructor.format constructor Print.path path EnumName.format name in let case_d = translate_expr ctx case_e in ( EnumConstructor.Map.add constructor case_d d_cases, @@ -224,16 +233,19 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : in if not (EnumConstructor.Map.is_empty remaining_e_cases) then Message.raise_spanned_error (Expr.pos e) - "Pattern matching is incomplete for enum %a: missing cases %a" + "Pattern matching is incomplete for enum %a%a: missing cases %a" + Print.path path EnumName.format name (EnumConstructor.Map.format_keys ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")) remaining_e_cases; let e1 = translate_expr ctx e1 in - Expr.ematch e1 name d_cases m - | EScopeCall { scope; args } -> + Expr.ematch ~e:e1 ~name ~cases:d_cases m + | EScopeCall { path; scope; args } -> let pos = Expr.mark_pos m in - let sc_sig = ScopeName.Map.find scope ctx.scopes_parameters in + let sc_sig = + module_scope_sig ctx.scopes_parameters path scope + in let in_var_map = ScopeVar.Map.merge (fun var_name (str_field : scope_input_var_ctx option) expr -> @@ -280,11 +292,17 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : in_var_map StructField.Map.empty in let arg_struct = - Expr.estruct sc_sig.scope_sig_input_struct field_map (mark_tany m pos) + Expr.estruct ~name:sc_sig.scope_sig_input_struct ~fields:field_map (mark_tany m pos) in let called_func = + let m = mark_tany m pos in + let e = match sc_sig.scope_sig_scope_ref with + | Local_scope_ref v -> Expr.evar v m + | External_scope_ref (path, name) -> + Expr.eexternal ~path ~name:(Mark.map (fun s -> External_scope s) name) m + in tag_with_log_entry - (Expr.evar sc_sig.scope_sig_scope_var (mark_tany m pos)) + e BeginCall [ScopeName.get_info scope; Mark.add (Expr.pos e) "direct"] in @@ -332,15 +350,15 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : (* result_eta_expanded = { struct_output_function_field = lambda x -> log (struct_output.struct_output_function_field x) ... } *) let result_eta_expanded = - Expr.estruct sc_sig.scope_sig_output_struct - (StructField.Map.mapi + Expr.estruct ~name:sc_sig.scope_sig_output_struct + ~fields:(StructField.Map.mapi (fun field typ -> let original_field_expr = Expr.estructaccess - (Expr.make_var result_var + ~e:(Expr.make_var result_var (Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))) - field sc_sig.scope_sig_output_struct (Expr.with_ty m typ) + ~field ~name:sc_sig.scope_sig_output_struct (Expr.with_ty m typ) in match Mark.remove typ with | TArrow (ts_in, t_out) -> @@ -387,7 +405,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : EndCall f_markings) ts_in (Expr.pos e) | _ -> original_field_expr) - (StructName.Map.find sc_sig.scope_sig_output_struct ctx.structs)) + (snd (StructName.Map.find sc_sig.scope_sig_output_struct ctx.decl_ctx.ctx_structs))) (Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e)) in (* Here we have to go through an if statement that records a decision being @@ -439,10 +457,10 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : match ctx.scope_name, Mark.remove f with | Some sname, ELocation loc -> ( match loc with - | ScopelangScopeVar (v, _) -> + | ScopelangScopeVar { name = (v, _); _ } -> [ScopeName.get_info sname; ScopeVar.get_info v] - | SubScopeVar (s, _, (v, _)) -> - [ScopeName.get_info s; ScopeVar.get_info v] + | SubScopeVar {scope; var = (v, _); _} -> + [ScopeName.get_info scope; ScopeVar.get_info v] | ToplevelVar _ -> []) | _ -> [] in @@ -453,8 +471,8 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : in let new_args = List.map (translate_expr ctx) args in let input_typs, output_typ = - (* NOTE: this is a temporary solution, it works because it's assume that - all function calls are from scope variable. However, this will change + (* NOTE: this is a temporary solution, it works because it's assumed that + all function calls are from scope variables. However, this will change -- for more information see https://github.com/CatalaLang/catala/pull/280#discussion_r898851693. *) let retrieve_in_and_out_typ_or_any var vars = @@ -465,19 +483,20 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : | _ -> ListLabels.map new_args ~f:(fun _ -> TAny), TAny in match Mark.remove f with - | ELocation (ScopelangScopeVar var) -> + | ELocation (ScopelangScopeVar {name = var}) -> retrieve_in_and_out_typ_or_any var ctx.scope_vars - | ELocation (SubScopeVar (_, sname, var)) -> + | ELocation (SubScopeVar { alias; var ; _}) -> ctx.subscope_vars - |> SubScopeName.Map.find (Mark.remove sname) + |> SubScopeName.Map.find (Mark.remove alias) |> retrieve_in_and_out_typ_or_any var - | ELocation (ToplevelVar tvar) -> ( - let _, typ = TopdefName.Map.find (Mark.remove tvar) ctx.toplevel_vars in - match typ with - | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout - | _ -> - Message.raise_spanned_error (Expr.pos e) - "Application of non-function toplevel variable") + | ELocation (ToplevelVar { path; name }) -> ( + let decl_ctx = Program.module_ctx ctx.decl_ctx path in + let typ = TopdefName.Map.find (Mark.remove name) decl_ctx.ctx_topdefs in + match Mark.remove typ with + | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout + | _ -> + Message.raise_spanned_error (Expr.pos e) + "Application of non-function toplevel variable") | _ -> ListLabels.map new_args ~f:(fun _ -> TAny), TAny in @@ -522,10 +541,10 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : Expr.edefault (List.map (translate_expr ctx) excepts) (translate_expr ctx just) (translate_expr ctx cons) m - | ELocation (ScopelangScopeVar a) -> + | ELocation (ScopelangScopeVar { name = a }) -> let v, _, _ = ScopeVar.Map.find (Mark.remove a) ctx.scope_vars in Expr.evar v m - | ELocation (SubScopeVar (_, s, a)) -> ( + | ELocation (SubScopeVar { alias = s; var = a; _ }) -> ( try let v, _, _ = ScopeVar.Map.find (Mark.remove a) @@ -545,13 +564,15 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : %a's results. Maybe you forgot to qualify it as an output?" SubScopeName.format (Mark.remove s) ScopeVar.format (Mark.remove a) SubScopeName.format (Mark.remove s)) - | ELocation (ToplevelVar v) -> - let v, _ = TopdefName.Map.find (Mark.remove v) ctx.toplevel_vars in + | ELocation (ToplevelVar { path = []; name }) -> + let v, _ = TopdefName.Map.find (Mark.remove name) ctx.toplevel_vars in Expr.evar v m + | ELocation (ToplevelVar { path = _::_ as path; name }) -> + Expr.eexternal ~path ~name:(Mark.map (fun n -> External_value n) name) m | EOp { op = Add_dat_dur _; tys } -> Expr.eop (Add_dat_dur ctx.date_rounding) tys m | EOp { op; tys } -> Expr.eop (Operator.translate op) tys m - | ( EVar _ | EAbs _ | ELit _ | EExternal _ | EStruct _ | EStructAccess _ + | ( EVar _ | EAbs _ | ELit _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EEmptyError | EErrorOnEmpty _ | EArray _ | EIfThenElse _ ) as e -> Expr.map ~f:(translate_expr ctx) (e, m) @@ -569,7 +590,7 @@ let translate_rule 'm Ast.expr scope_body_expr Bindlib.box) * 'm ctx = match rule with - | Definition ((ScopelangScopeVar a, var_def_pos), tau, a_io, e) -> + | Definition ((ScopelangScopeVar { name = a }, var_def_pos), tau, a_io, e) -> let pos_mark, pos_mark_as = pos_mark_mk e in let a_name = ScopeVar.get_info (Mark.remove a) in let a_var = Var.make (Mark.remove a_name) in @@ -615,7 +636,7 @@ let translate_rule ctx.scope_vars; } ) | Definition - ( (SubScopeVar (_subs_name, subs_index, subs_var), var_def_pos), + ( (SubScopeVar { alias = subs_index; var = subs_var; _ }, var_def_pos), tau, a_io, e ) -> @@ -681,8 +702,12 @@ let translate_rule (* A global variable can't be defined locally. The [Definition] constructor could be made more specific to avoid this case, but the added complexity didn't seem worth it *) - | Call (subname, subindex, m) -> - let subscope_sig = ScopeName.Map.find subname ctx.scopes_parameters in + | Call ((path, subname), subindex, m) -> + let subscope_sig = module_scope_sig ctx.scopes_parameters path subname in + let scope_sig_decl = + ScopeName.Map.find subname + (Program.module_ctx ctx.decl_ctx path).ctx_scopes + in let all_subscope_vars = subscope_sig.scope_sig_local_vars in let all_subscope_input_vars = List.filter @@ -698,7 +723,14 @@ let translate_rule Mark.remove var_ctx.scope_var_io.Desugared.Ast.io_output) all_subscope_vars in - let scope_dcalc_var = subscope_sig.scope_sig_scope_var in + let pos_call = Mark.get (SubScopeName.get_info subindex) in + let scope_dcalc_ref = + let m = mark_tany m pos_call in + match subscope_sig.scope_sig_scope_ref with + | Local_scope_ref var -> Expr.make_var var m + | External_scope_ref (path, name) -> + Expr.eexternal ~path ~name:(Mark.map (fun n -> External_scope n) name) m + in let called_scope_input_struct = subscope_sig.scope_sig_input_struct in let called_scope_return_struct = subscope_sig.scope_sig_output_struct in let subscope_vars_defined = @@ -708,7 +740,6 @@ let translate_rule let subscope_var_not_yet_defined subvar = not (ScopeVar.Map.mem subvar subscope_vars_defined) in - let pos_call = Mark.get (SubScopeName.get_info subindex) in let subscope_args = List.fold_left (fun acc (subvar : scope_var_ctx) -> @@ -734,7 +765,7 @@ let translate_rule StructField.Map.empty all_subscope_input_vars in let subscope_struct_arg = - Expr.estruct called_scope_input_struct subscope_args + Expr.estruct ~name:called_scope_input_struct ~fields:subscope_args (mark_tany m pos_call) in let all_subscope_output_vars_dcalc = @@ -751,7 +782,7 @@ let translate_rule in let subscope_func = tag_with_log_entry - (Expr.make_var scope_dcalc_var (mark_tany m pos_call)) + scope_dcalc_ref BeginCall [ sigma_name, pos_sigma; @@ -790,7 +821,7 @@ let translate_rule (fun (var_ctx, v) next -> let field = ScopeVar.Map.find var_ctx.scope_var_name - subscope_sig.scope_sig_out_fields + scope_sig_decl.out_struct_fields in Bindlib.box_apply2 (fun next r -> @@ -849,6 +880,7 @@ let translate_rule let translate_rules (ctx : 'm ctx) + (scope_name : ScopeName.t) (rules : 'm Scopelang.Ast.rule list) ((sigma_name, pos_sigma) : Uid.MarkedString.info) (mark : 'm mark) @@ -864,12 +896,15 @@ let translate_rules ((fun next -> next), ctx) rules in + let scope_sig_decl = + ScopeName.Map.find scope_name ctx.decl_ctx.ctx_scopes + in let return_exp = - Expr.estruct scope_sig.scope_sig_output_struct - (ScopeVar.Map.fold + Expr.estruct ~name:scope_sig.scope_sig_output_struct + ~fields:(ScopeVar.Map.fold (fun var (dcalc_var, _, io) acc -> if Mark.remove io.Desugared.Ast.io_output then - let field = ScopeVar.Map.find var scope_sig.scope_sig_out_fields in + let field = ScopeVar.Map.find var scope_sig_decl.out_struct_fields in StructField.Map.add field (Expr.make_var dcalc_var (mark_tany mark pos_sigma)) acc @@ -883,6 +918,7 @@ let translate_rules (Expr.Box.lift return_exp)), new_ctx ) +(* From a scope declaration and definitions, create the corresponding scope body wrapped in the appropriate call convention. *) let translate_scope_decl (ctx : 'm ctx) (scope_name : ScopeName.t) @@ -890,7 +926,7 @@ let translate_scope_decl 'm Ast.expr scope_body Bindlib.box * struct_ctx = let sigma_info = ScopeName.get_info sigma.scope_decl_name in let scope_sig = - ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters + ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters.scope_sigs in let scope_variables = scope_sig.scope_sig_local_vars in let ctx = { ctx with scope_name = Some scope_name } in @@ -926,12 +962,24 @@ let translate_scope_decl | None -> AbortOnRound in let ctx = { ctx with date_rounding } in - let scope_input_var = scope_sig.scope_sig_input_var in + let scope_input_var = + Var.make (Mark.remove (ScopeName.get_info scope_name) ^ "_in") + in let scope_input_struct_name = scope_sig.scope_sig_input_struct in let scope_return_struct_name = scope_sig.scope_sig_output_struct in let pos_sigma = Mark.get sigma_info in + let scope_mark = + (* Find a witness of a mark in the definitions *) + match sigma.scope_decl_rules with + | [] -> + (* Todo: are we sure this can't happen in normal code ? E.g. is calling a scope which only defines input variables already an error at this stage or not ? *) + Message.raise_spanned_error pos_sigma "Scope %a has no content" ScopeName.format scope_name + | (Definition (_,_,_,(_,m)) | Assertion (_,m) | Call (_,_,m)) :: _ -> + m + in let rules_with_return_expr, ctx = - translate_rules ctx sigma.scope_decl_rules sigma_info sigma.scope_mark + translate_rules ctx scope_name sigma.scope_decl_rules sigma_info + scope_mark scope_sig in let scope_variables = @@ -982,14 +1030,25 @@ let translate_scope_decl scope_let_expr = ( EStructAccess { name = scope_input_struct_name; e = r; field }, - mark_tany sigma.scope_mark pos_sigma ); + mark_tany scope_mark pos_sigma ); }) (Bindlib.bind_var v next) (Expr.Box.lift (Expr.make_var scope_input_var - (mark_tany sigma.scope_mark pos_sigma)))) + (mark_tany scope_mark pos_sigma)))) scope_input_variables next in + let scope_body = + Bindlib.box_apply + (fun scope_body_expr -> + { + scope_body_expr; + scope_body_input_struct = scope_input_struct_name; + scope_body_output_struct = scope_return_struct_name; + }) + (Bindlib.bind_var scope_input_var + (input_destructurings rules_with_return_expr)) + in let field_map = List.fold_left (fun acc (var_ctx, _) -> @@ -1001,17 +1060,9 @@ let translate_scope_decl StructField.Map.empty scope_input_variables in let new_struct_ctx = - StructName.Map.singleton scope_input_struct_name field_map + StructName.Map.singleton scope_input_struct_name ([], field_map) in - ( Bindlib.box_apply - (fun scope_body_expr -> - { - scope_body_expr; - scope_body_input_struct = scope_input_struct_name; - scope_body_output_struct = scope_return_struct_name; - }) - (Bindlib.bind_var scope_input_var - (input_destructurings rules_with_return_expr)), + ( scope_body, new_struct_ctx ) let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = @@ -1021,23 +1072,29 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = Scopelang.Dependency.get_defs_ordering defs_dependencies in let decl_ctx = prgm.program_ctx in + Message.emit_debug "prog scopes: %a@ modules: %a" + (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) prgm.program_scopes + (ModuleName.Map.format + (fun fmt prg -> ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt prg.Scopelang.Ast.program_scopes)) prgm.program_modules; let sctx : 'm scope_sigs_ctx = - ScopeName.Map.mapi - (fun scope_name scope -> - let scope_dvar = - Var.make - (Mark.remove - (ScopeName.get_info scope.Scopelang.Ast.scope_decl_name)) - in - let scope_return = ScopeName.Map.find scope_name decl_ctx.ctx_scopes in - let scope_input_var = - Var.make (Mark.remove (ScopeName.get_info scope_name) ^ "_in") - in - let scope_input_struct_name = - StructName.fresh - (Mark.map (fun s -> s ^ "_in") (ScopeName.get_info scope_name)) - in - let scope_sig_in_fields = + let process_scope_sig (scope_path, scope_name) scope = + Message.emit_debug "process_scope_sig %a%a (%a)" + Print.path scope_path ScopeName.format scope_name ScopeName.format scope.Scopelang.Ast.scope_decl_name; + let scope_ref = + match scope_path with + | [] -> + let v = Var.make (Mark.remove (ScopeName.get_info scope_name)) in + Local_scope_ref v + | path -> + External_scope_ref (path, Mark.copy (ScopeName.get_info scope_name) scope_name) + in + let scope_info = + try + ScopeName.Map.find scope_name (Program.module_ctx decl_ctx scope_path).ctx_scopes + with Not_found -> Message.raise_spanned_error (Mark.get (ScopeName.get_info scope_name)) "Could not find scope %a%a" Print.path scope_path ScopeName.format scope_name + in + let scope_sig_in_fields = + (* Output fields have already been generated and added to the program ctx at this point, because they are visible to the user (manipulated as the return type of ScopeCalls) ; but input fields are used purely internally and need to be created here to implement the call convention for scopes. *) ScopeVar.Map.filter_map (fun dvar (typ, vis) -> match Mark.remove vis.Desugared.Ast.io_input with @@ -1051,7 +1108,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = scope_input_io = vis.Desugared.Ast.io_input; scope_input_typ = Mark.remove typ; }) - scope.scope_sig + scope.Scopelang.Ast.scope_sig in { scope_sig_local_vars = @@ -1063,15 +1120,55 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = scope_var_io = vis; }) (ScopeVar.Map.bindings scope.scope_sig); - scope_sig_scope_var = scope_dvar; - scope_sig_input_var = scope_input_var; - scope_sig_input_struct = scope_input_struct_name; - scope_sig_output_struct = scope_return.out_struct_name; + scope_sig_scope_ref = scope_ref; + scope_sig_input_struct = scope_info.in_struct_name; + scope_sig_output_struct = scope_info.out_struct_name; scope_sig_in_fields; - scope_sig_out_fields = scope_return.out_struct_fields; - }) - prgm.Scopelang.Ast.program_scopes + } + in + let rec process_modules path prg = + { scope_sigs = + ScopeName.Map.mapi + (fun scope_name (scope_decl, _) -> + process_scope_sig (path, scope_name) scope_decl) + prg.Scopelang.Ast.program_scopes; + scope_sigs_modules = + ModuleName.Map.mapi (fun modname prg -> + process_modules (path @ [modname, Pos.no_pos]) prg) + prg.Scopelang.Ast.program_modules; + } + in + { scope_sigs = + ScopeName.Map.mapi + (fun scope_name (scope_decl, _) -> + process_scope_sig ([], scope_name) scope_decl) + prgm.Scopelang.Ast.program_scopes; + scope_sigs_modules = + ModuleName.Map.mapi (fun modname prg -> + process_modules [modname, Pos.no_pos] prg) + prgm.Scopelang.Ast.program_modules; + } in + let rec gather_module_in_structs acc path sctx = + (* Expose all added in_structs from submodules at toplevel *) + ModuleName.Map.fold (fun modname scope_sigs acc -> + let path = path @ [modname, Pos.no_pos] in + let acc = gather_module_in_structs acc path scope_sigs.scope_sigs_modules in + ScopeName.Map.fold (fun _ scope_sig_ctx acc -> + let fields = + ScopeVar.Map.fold (fun _ sivc acc -> + let pos = Mark.get (StructField.get_info sivc.scope_input_name) in + StructField.Map.add sivc.scope_input_name (sivc.scope_input_typ, pos) acc) + scope_sig_ctx.scope_sig_in_fields StructField.Map.empty + in + StructName.Map.add scope_sig_ctx.scope_sig_input_struct + (path, fields) acc) + scope_sigs.scope_sigs acc + ) + sctx + acc + in + let decl_ctx = { decl_ctx with ctx_structs = gather_module_in_structs decl_ctx.ctx_structs [] sctx.scope_sigs_modules } in let top_ctx = let toplevel_vars = TopdefName.Map.mapi @@ -1080,8 +1177,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = prgm.Scopelang.Ast.program_topdefs in { - structs = decl_ctx.ctx_structs; - enums = decl_ctx.ctx_enums; + decl_ctx; scope_name = None; scopes_parameters = sctx; scope_vars = ScopeVar.Map.empty; @@ -1109,16 +1205,24 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = | Scopelang.Dependency.Scope scope_name -> let scope = ScopeName.Map.find scope_name prgm.program_scopes in let scope_body, scope_in_struct = - translate_scope_decl ctx scope_name scope + translate_scope_decl ctx scope_name (Mark.remove scope) + in + let scope_var = + match (ScopeName.Map.find scope_name sctx.scope_sigs).scope_sig_scope_ref with + | Local_scope_ref v -> v + | External_scope_ref _ -> assert false in ( { ctx with - structs = - StructName.Map.union - (fun _ _ -> assert false) - ctx.structs scope_in_struct; + decl_ctx = + { ctx.decl_ctx with + ctx_structs = + StructName.Map.union + (fun _ _ -> assert false) + ctx.decl_ctx.ctx_structs scope_in_struct; + } }, - (ScopeName.Map.find scope_name sctx).scope_sig_scope_var, + scope_var, Bindlib.box_apply (fun body -> ScopeDef (scope_name, body)) scope_body ) @@ -1131,7 +1235,8 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = ctx ) in let items, ctx = translate_defs top_ctx defs_ordering in + (* WIP TODO FIXME HERE: the scopes in submodules are not translated here it seems, and their input structs not added to decl_ctx (see From_surface:1476 for decl_ctx flattening info) *) { code_items = Bindlib.unbox items; - decl_ctx = { decl_ctx with ctx_structs = ctx.structs }; + decl_ctx = ctx.decl_ctx; } diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 3738b35c..734037b0 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -215,7 +215,7 @@ type var_or_states = WholeVar | States of StateName.t list type scope = { scope_vars : var_or_states ScopeVar.Map.t; - scope_sub_scopes : ScopeName.t SubScopeName.Map.t; + scope_sub_scopes : (path * ScopeName.t) SubScopeName.Map.t; scope_uid : ScopeName.t; scope_defs : scope_def ScopeDef.Map.t; scope_assertions : assertion AssertionName.Map.t; @@ -227,6 +227,7 @@ type program = { program_scopes : scope ScopeName.Map.t; program_topdefs : (expr option * typ) TopdefName.Map.t; program_ctx : decl_ctx; + program_modules : program ModuleName.Map.t; } let rec locations_used e : LocationSet.t = @@ -247,11 +248,11 @@ let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDef.Map.t = (fun (loc, loc_pos) acc -> let usage = match loc with - | DesugaredScopeVar (v, st) -> Some (ScopeDef.Var (Mark.remove v, st)) - | SubScopeVar (_, sub_index, sub_var) -> + | DesugaredScopeVar { name; state } -> Some (ScopeDef.Var (Mark.remove name, state)) + | SubScopeVar { alias; var; _ } -> Some (ScopeDef.SubScopeVar - (Mark.remove sub_index, Mark.remove sub_var, Mark.get sub_index)) + (Mark.remove alias, Mark.remove var, Mark.get alias)) | ToplevelVar _ -> None in match usage with diff --git a/compiler/desugared/ast.mli b/compiler/desugared/ast.mli index 6954fba0..77dd3054 100644 --- a/compiler/desugared/ast.mli +++ b/compiler/desugared/ast.mli @@ -104,7 +104,7 @@ type var_or_states = WholeVar | States of StateName.t list type scope = { scope_vars : var_or_states ScopeVar.Map.t; - scope_sub_scopes : ScopeName.t SubScopeName.Map.t; + scope_sub_scopes : (path * ScopeName.t) SubScopeName.Map.t; scope_uid : ScopeName.t; scope_defs : scope_def ScopeDef.Map.t; scope_assertions : assertion AssertionName.Map.t; @@ -116,6 +116,7 @@ type program = { program_scopes : scope ScopeName.Map.t; program_topdefs : (expr option * typ) TopdefName.Map.t; program_ctx : decl_ctx; + program_modules : program ModuleName.Map.t; } (** {1 Helpers} *) diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 64755a52..6905c2af 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -261,9 +261,9 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = (fun used_var g -> let edge_from = match Mark.remove used_var with - | DesugaredScopeVar (v, s) -> Some (Vertex.Var (Mark.remove v, s)) - | SubScopeVar (_, subscope_name, _) -> - Some (Vertex.SubScope (Mark.remove subscope_name)) + | DesugaredScopeVar { name; state } -> Some (Vertex.Var (Mark.remove name, state)) + | SubScopeVar { alias; _ } -> + Some (Vertex.SubScope (Mark.remove alias)) | ToplevelVar _ -> None (* we don't add this dependency because toplevel definitions are outside the scope *) diff --git a/compiler/desugared/disambiguate.ml b/compiler/desugared/disambiguate.ml index 9de4f466..ba4946e6 100644 --- a/compiler/desugared/disambiguate.ml +++ b/compiler/desugared/disambiguate.ml @@ -62,11 +62,40 @@ let scope ctx env scope = { scope with scope_defs; scope_assertions } let program prg = + let base_typing_env prg = + let env = + TopdefName.Map.fold + (fun name (_e, ty) env -> Typing.Env.add_toplevel_var name ty env) + prg.program_topdefs + (Typing.Env.empty prg.program_ctx) + in + let env = + ScopeName.Map.fold + (fun scope_name scope env -> + let vars = + ScopeDef.Map.fold + (fun var def vars -> + match var with + | Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars + | SubScopeVar _ -> vars) + scope.scope_defs ScopeVar.Map.empty + in + Typing.Env.add_scope scope_name ~vars env) + prg.program_scopes env + in + env + in + let rec build_typing_env prg = + ModuleName.Map.fold (fun modname prg -> + Typing.Env.add_module modname ~module_env:(build_typing_env prg)) + prg.program_modules + (base_typing_env prg) + in let env = - TopdefName.Map.fold - (fun name (_e, ty) env -> Typing.Env.add_toplevel_var name ty env) - prg.program_topdefs - (Typing.Env.empty prg.program_ctx) + ModuleName.Map.fold (fun modname prg -> + Typing.Env.add_module modname ~module_env:(build_typing_env prg)) + prg.program_modules + (base_typing_env prg) in let program_topdefs = TopdefName.Map.map @@ -76,20 +105,6 @@ let program prg = | None, ty -> None, ty) prg.program_topdefs in - let env = - ScopeName.Map.fold - (fun scope_name scope env -> - let vars = - ScopeDef.Map.fold - (fun var def vars -> - match var with - | Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars - | SubScopeVar _ -> vars) - scope.scope_defs ScopeVar.Map.empty - in - Typing.Env.add_scope scope_name ~vars env) - prg.program_scopes env - in let program_scopes = ScopeName.Map.map (scope prg.program_ctx env) prg.program_scopes in diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index eab2d0a5..dac41054 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -34,7 +34,7 @@ module Runtime = Runtime_ocaml.Runtime the operator suffixes for explicit typing. See {!modules: Shared_ast.Operator} for detail. *) -let translate_binop : Surface.Ast.binop -> Pos.t -> Ast.expr boxed = +let translate_binop : S.binop -> Pos.t -> Ast.expr boxed = fun op pos -> let op_expr op tys = Expr.eop op (List.map (Mark.add pos) tys) (Untyped { pos }) @@ -104,7 +104,7 @@ let translate_binop : Surface.Ast.binop -> Pos.t -> Ast.expr boxed = | S.Neq -> assert false (* desugared already *) | S.Concat -> op_expr Concat [TArray (TAny, pos); TArray (TAny, pos)] -let translate_unop (op : Surface.Ast.unop) pos : Ast.expr boxed = +let translate_unop (op : S.unop) pos : Ast.expr boxed = let op_expr op ty = Expr.eop op [Mark.add pos ty] (Untyped { pos }) in match op with | S.Not -> op_expr Not (TLit TBool) @@ -134,12 +134,12 @@ let raise_error_cons_not_found "The name of this constructor has not been defined before@ (it's probably \ a typographical error)." -let disambiguate_constructor +let rec disambiguate_constructor (ctxt : Name_resolution.context) - (constructor : (S.path * S.uident Mark.pos) Mark.pos list) + (constructor0 : (S.path * S.uident Mark.pos) Mark.pos list) (pos : Pos.t) : EnumName.t * EnumConstructor.t = let path, constructor = - match constructor with + match constructor0 with | [c] -> Mark.remove c | _ -> Message.raise_spanned_error pos @@ -172,7 +172,13 @@ let disambiguate_constructor with Not_found -> Message.raise_spanned_error (Mark.get enum) "Enum %s has not been defined before" (Mark.remove enum)) - | _ -> Message.raise_spanned_error pos "Qualified paths are not supported yet" + | (modname, mpos)::path -> + match ModuleName.Map.find_opt modname ctxt.modules with + | None -> + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname + | Some ctxt -> + let constructor = List.map (Mark.map (fun (_, c) -> path, c)) constructor0 in + disambiguate_constructor ctxt constructor pos let int100 = Runtime.integer_of_int 100 let rat100 = Runtime.decimal_of_integer int100 @@ -204,19 +210,22 @@ let rec translate_expr (scope : ScopeName.t option) (inside_definition_of : Ast.ScopeDef.t Mark.pos option) (ctxt : Name_resolution.context) - (expr : Surface.Ast.expression) : Ast.expr boxed = + (local_vars : Ast.expr Var.t Ident.Map.t) + (expr : S.expression) : Ast.expr boxed = let scope_vars = match scope with | None -> Ident.Map.empty | Some s -> (ScopeName.Map.find s ctxt.scopes).var_idmap in - let rec_helper = translate_expr scope inside_definition_of ctxt in + let rec_helper ?(local_vars=local_vars) e = + translate_expr scope inside_definition_of ctxt local_vars e + in let pos = Mark.get expr in let emark = Untyped { pos } in match Mark.remove expr with | Paren e -> rec_helper e | Binop - ( (Surface.Ast.And, _pos_op), + ( (S.And, _pos_op), ( TestMatchCase (e1_sub, ((constructors, Some binding), pos_pattern)), _pos_e1 ), e2 ) -> @@ -234,16 +243,15 @@ let rec translate_expr (Expr.elit (LBool false) emark) [tau] pos else - let ctxt, binding_var = - Name_resolution.add_def_local_var ctxt (Mark.remove binding) - in - let e2 = translate_expr scope inside_definition_of ctxt e2 in + let binding_var = Var.make (Mark.remove binding) in + let local_vars = Ident.Map.add (Mark.remove binding) binding_var local_vars in + let e2 = rec_helper ~local_vars e2 in Expr.make_abs [| binding_var |] e2 [tau] pos) (EnumName.Map.find enum_uid ctxt.enums) in Expr.ematch - (translate_expr scope inside_definition_of ctxt e1_sub) - enum_uid cases emark + ~e:(rec_helper e1_sub) + ~name:enum_uid ~cases emark | Binop ((((S.And | S.Or | S.Xor), _) as op), e1, e2) -> check_formula op e1; check_formula op e2; @@ -311,7 +319,7 @@ let rec translate_expr | Ident ([], (x, pos)) -> ( (* first we check whether this is a local var, then we resort to scope-wide variables, then global variables *) - match Ident.Map.find_opt x ctxt.local_var_idmap with + match Ident.Map.find_opt x local_vars with | Some uid -> Expr.make_var uid emark (* the whole box thing is to accomodate for this case *) @@ -343,20 +351,18 @@ let rec translate_expr else (* Tricky: we have to retrieve in the list the previous state with respect to the state that we are defining. *) - let correct_state = ref None in - ignore - (List.fold_left - (fun previous_state state -> - if StateName.equal inside_def_state state then - correct_state := previous_state; - Some state) - None states); - !correct_state) + let rec find_prev_state = function + | [] -> None + | st0 :: st1 :: _ when StateName.equal inside_def_state st1 -> + Some st0 + | _ :: states -> find_prev_state states + in + find_prev_state states) | _ -> (* we take the last state in the chain *) Some (List.hd (List.rev states))) in - Expr.elocation (DesugaredScopeVar ((uid, pos), x_state)) emark + Expr.elocation (DesugaredScopeVar { name = uid, pos; state = x_state }) emark | Some (SubScope _) (* Note: allowing access to a global variable with the same name as a subscope is disputable, but I see no good reason to forbid it either *) @@ -364,56 +370,74 @@ let rec translate_expr match Ident.Map.find_opt x ctxt.topdefs with | Some v -> Expr.elocation - (ToplevelVar (v, Mark.get (TopdefName.get_info v))) + (ToplevelVar { path = []; name = v, Mark.get (TopdefName.get_info v) }) emark | None -> Name_resolution.raise_unknown_identifier "for a local, scope-wide or global variable" (x, pos)))) - | Surface.Ast.Ident (path, x) -> - let path = List.map Mark.remove path in - Expr.eexternal (path, Mark.remove x) emark + | Ident (path, name) -> + let ctxt = Name_resolution.module_ctx ctxt path in + (match Ident.Map.find_opt (Mark.remove name) ctxt.topdefs with + | Some v -> + Expr.elocation + (ToplevelVar { path; name = v, Mark.get (TopdefName.get_info v) }) + emark + | None -> + Name_resolution.raise_unknown_identifier + "for an external variable" name) | Dotted (e, ((path, x), _ppos)) -> ( match path, Mark.remove e with | [], Ident ([], (y, _)) when Option.fold scope ~none:false ~some:(fun s -> Name_resolution.is_subscope_uid s ctxt y) -> (* In this case, y.x is a subscope variable *) - let subscope_uid, subscope_real_uid = + let subscope_uid, (subscope_path, subscope_real_uid) = match Ident.Map.find y scope_vars with | SubScope (sub, sc) -> sub, sc | ScopeVar _ -> assert false in let subscope_var_uid = + let ctxt = Name_resolution.module_ctx ctxt subscope_path in Name_resolution.get_var_uid subscope_real_uid ctxt x in Expr.elocation - (SubScopeVar - (subscope_real_uid, (subscope_uid, pos), (subscope_var_uid, pos))) + (SubScopeVar { + path = subscope_path; + scope = subscope_real_uid; + alias = (subscope_uid, pos); + var = (subscope_var_uid, pos) + }) emark | _ -> (* In this case e.x is the struct field x access of expression e *) - let e = translate_expr scope inside_definition_of ctxt e in - let str = - match path with + let e = rec_helper e in + let rec get_str ctxt = function | [] -> None | [c] -> ( try Some (Name_resolution.get_struct ctxt c) with Not_found -> Message.raise_spanned_error (Mark.get c) "Structure %s was not declared" (Mark.remove c)) - | _ -> - Message.raise_spanned_error pos - "Qualified paths are not supported yet" + | (modname, mpos) :: path -> + match ModuleName.Map.find_opt modname ctxt.modules with + | None -> + Message.raise_spanned_error mpos + "Module %a not found" ModuleName.format modname + | Some ctxt -> + get_str ctxt path in - Expr.edstructaccess e (Mark.remove x) str emark) + Expr.edstructaccess ~e ~field:(Mark.remove x) ~name_opt:(get_str ctxt path) ~path emark) | FunCall (f, args) -> Expr.eapp (rec_helper f) (List.map rec_helper args) emark - | ScopeCall ((([], sc_name), _), fields) -> + | ScopeCall (((path, id), _), fields) -> if scope = None then Message.raise_spanned_error pos "Scope calls are not allowed outside of a scope"; - let called_scope = Name_resolution.get_scope ctxt sc_name in - let scope_def = ScopeName.Map.find called_scope ctxt.scopes in + let called_scope, scope_def = + let ctxt = Name_resolution.module_ctx ctxt path in + let uid = Name_resolution.get_scope ctxt id in + uid, ScopeName.Map.find uid ctxt.scopes + in let in_struct = List.fold_left (fun acc (fld_id, e) -> @@ -444,16 +468,15 @@ let rec translate_expr acc) ScopeVar.Map.empty fields in - Expr.escopecall called_scope in_struct emark - | ScopeCall (((_, _sc_name), _), _fields) -> - Message.raise_spanned_error pos "Qualified paths are not supported yet" + Expr.escopecall ~path ~scope:called_scope ~args:in_struct emark | LetIn (x, e1, e2) -> - let ctxt, v = Name_resolution.add_def_local_var ctxt (Mark.remove x) in + let v = Var.make (Mark.remove x) in + let local_vars = Ident.Map.add (Mark.remove x) v local_vars in let tau = TAny, Mark.get x in (* This type will be resolved in Scopelang.Desambiguation *) let fn = Expr.make_abs [| v |] - (translate_expr scope inside_definition_of ctxt e2) + (rec_helper ~local_vars e2) [tau] pos in Expr.eapp fn [rec_helper e1] emark @@ -484,7 +507,7 @@ let rec translate_expr Message.raise_multispanned_error [None, Mark.get f_e; None, Expr.pos e_field] "The field %a has been defined twice:" StructField.format f_uid); - let f_e = translate_expr scope inside_definition_of ctxt f_e in + let f_e = rec_helper f_e in StructField.Map.add f_uid f_e s_fields) StructField.Map.empty fields in @@ -497,12 +520,12 @@ let rec translate_expr StructField.format expected_f) expected_s_fields; - Expr.estruct s_uid s_fields emark + Expr.estruct ~name:s_uid ~fields:s_fields emark | StructLit (((_, _s_name), _), _fields) -> Message.raise_spanned_error pos "Qualified paths are not supported yet" | EnumInject (((path, (constructor, pos_constructor)), _), payload) -> ( - let possible_c_uids = - try Ident.Map.find constructor ctxt.constructor_idmap + let get_possible_c_uids ctxt = + try Ident.Map.find constructor ctxt.Name_resolution.constructor_idmap with Not_found -> raise_error_cons_not_found ctxt (constructor, pos_constructor) in @@ -510,8 +533,9 @@ let rec translate_expr match path with | [] -> + let possible_c_uids = get_possible_c_uids ctxt in if - (* No constructor name was specified *) + (* No enum name was specified *) EnumName.Map.cardinal possible_c_uids > 1 then Message.raise_spanned_error pos_constructor @@ -522,43 +546,46 @@ let rec translate_expr possible_c_uids else let e_uid, c_uid = EnumName.Map.choose possible_c_uids in - let payload = - Option.map (translate_expr scope inside_definition_of ctxt) payload - in + let payload = Option.map rec_helper payload in Expr.einj - (match payload with + ~e:(match payload with | Some e' -> e' | None -> Expr.elit LUnit mark_constructor) - c_uid e_uid emark - | [enum] -> ( + ~cons:c_uid ~name:e_uid emark + | path_enum -> ( + let path, enum = match List.rev path_enum with + | enum :: rpath -> List.rev rpath, enum + | _ -> assert false + in try - (* The path has been fully qualified *) + let ctxt = Name_resolution.module_ctx ctxt path in + let possible_c_uids = get_possible_c_uids ctxt in + (* The path has been qualified *) let e_uid = Name_resolution.get_enum ctxt enum in try let c_uid = EnumName.Map.find e_uid possible_c_uids in let payload = - Option.map (translate_expr scope inside_definition_of ctxt) payload + Option.map rec_helper payload in Expr.einj - (match payload with + ~e:(match payload with | Some e' -> e' | None -> Expr.elit LUnit mark_constructor) - c_uid e_uid emark + ~cons:c_uid ~name:e_uid emark with Not_found -> Message.raise_spanned_error pos "Enum %s does not contain case %s" (Mark.remove enum) constructor with Not_found -> Message.raise_spanned_error (Mark.get enum) - "Enum %s has not been defined before" (Mark.remove enum)) - | _ -> - Message.raise_spanned_error pos "Qualified paths are not supported yet") + "Enum %s has not been defined" (Mark.remove enum))) | MatchWith (e1, (cases, _cases_pos)) -> - let e1 = translate_expr scope inside_definition_of ctxt e1 in + let e1 = rec_helper e1 in let cases_d, e_uid = disambiguate_match_and_build_expression scope inside_definition_of ctxt + local_vars cases in - Expr.ematch e1 e_uid cases_d emark + Expr.ematch ~e:e1 ~name:e_uid ~cases:cases_d emark | TestMatchCase (e1, pattern) -> (match snd (Mark.remove pattern) with | None -> () @@ -580,18 +607,17 @@ let rec translate_expr (EnumName.Map.find enum_uid ctxt.enums) in Expr.ematch - (translate_expr scope inside_definition_of ctxt e1) - enum_uid cases emark + ~e:(rec_helper e1) + ~name:enum_uid ~cases:cases emark | ArrayLit es -> Expr.earray (List.map rec_helper es) emark | CollectionOp (((S.Filter { f } | S.Map { f }) as op), collection) -> let collection = rec_helper collection in - let param, predicate = f in - let ctxt, param = - Name_resolution.add_def_local_var ctxt (Mark.remove param) - in + let param_name, predicate = f in + let param = Var.make (Mark.remove param_name) in + let local_vars = Ident.Map.add (Mark.remove param_name) param local_vars in let f_pred = Expr.make_abs [| param |] - (translate_expr scope inside_definition_of ctxt predicate) + (rec_helper ~local_vars predicate) [TAny, pos] pos in @@ -605,18 +631,17 @@ let rec translate_expr emark) [f_pred; collection] emark | CollectionOp - (S.AggregateArgExtremum { max; default; f = param, predicate }, collection) + (S.AggregateArgExtremum { max; default; f = param_name, predicate }, collection) -> let default = rec_helper default in let pos_dft = Expr.pos default in let collection = rec_helper collection in - let ctxt, param = - Name_resolution.add_def_local_var ctxt (Mark.remove param) - in + let param = Var.make (Mark.remove param_name) in + let local_vars = Ident.Map.add (Mark.remove param_name) param local_vars in let cmp_op = if max then Op.Gt else Op.Lt in let f_pred = Expr.make_abs [| param |] - (translate_expr scope inside_definition_of ctxt predicate) + (rec_helper ~local_vars predicate) [TAny, pos] pos in @@ -655,16 +680,15 @@ let rec translate_expr in let init = Expr.elit (LBool init) emark in let param0, predicate = predicate in - let ctxt, param = - Name_resolution.add_def_local_var ctxt (Mark.remove param0) - in + let param = Var.make (Mark.remove param0) in + let local_vars = Ident.Map.add (Mark.remove param0) param local_vars in let f = let acc_var = Var.make "acc" in let acc = Expr.make_var acc_var (Untyped { pos = Mark.get param0 }) in Expr.eabs (Expr.bind [| acc_var; param |] (Expr.eapp (translate_binop op pos) - [acc; translate_expr scope inside_definition_of ctxt predicate] + [acc; rec_helper ~local_vars predicate] emark)) [TAny, pos; TAny, pos] emark @@ -674,7 +698,7 @@ let rec translate_expr [f; init; collection] emark | CollectionOp (AggregateExtremum { max; default }, collection) -> let collection = rec_helper collection in - let default = translate_expr scope inside_definition_of ctxt default in + let default = rec_helper default in let op = translate_binop (if max then S.Gt KPoly else S.Lt KPoly) pos in let op_f = (* fun x1 x2 -> if op x1 x2 then x1 else x2 *) @@ -729,7 +753,7 @@ let rec translate_expr let acc_var = Var.make "acc" in let acc = Expr.make_var acc_var emark in let f_body = - let member = translate_expr scope inside_definition_of ctxt member in + let member = rec_helper member in Expr.eapp (Expr.eop Or [TLit TBool, pos; TLit TBool, pos] emark) [ @@ -763,13 +787,14 @@ and disambiguate_match_and_build_expression (scope : ScopeName.t option) (inside_definition_of : Ast.ScopeDef.t Mark.pos option) (ctxt : Name_resolution.context) - (cases : Surface.Ast.match_case Mark.pos list) : + (local_vars : Ast.expr Var.t Ident.Map.t) + (cases : S.match_case Mark.pos list) : Ast.expr boxed EnumConstructor.Map.t * EnumName.t = - let create_var = function - | None -> ctxt, Var.make "_" + let create_var local_vars = function + | None -> local_vars, Var.make "_" | Some param -> - let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in - ctxt, param_var + let param_var = Var.make param in + Ident.Map.add param param_var local_vars, param_var in let bind_case_body (c_uid : EnumConstructor.t) @@ -786,13 +811,13 @@ and disambiguate_match_and_build_expression in let bind_match_cases (cases_d, e_uid, curr_index) (case, case_pos) = match case with - | Surface.Ast.MatchCase case -> + | S.MatchCase case -> let constructor, binding = - Mark.remove case.Surface.Ast.match_case_pattern + Mark.remove case.S.match_case_pattern in let e_uid', c_uid = disambiguate_constructor ctxt constructor - (Mark.get case.Surface.Ast.match_case_pattern) + (Mark.get case.S.match_case_pattern) in let e_uid = match e_uid with @@ -801,7 +826,7 @@ and disambiguate_match_and_build_expression if e_uid = e_uid' then e_uid else Message.raise_spanned_error - (Mark.get case.Surface.Ast.match_case_pattern) + (Mark.get case.S.match_case_pattern) "This case matches a constructor of enumeration %a but previous \ case were matching constructors of enumeration %a" EnumName.format e_uid EnumName.format e_uid' @@ -813,17 +838,17 @@ and disambiguate_match_and_build_expression [None, Mark.get case.match_case_expr; None, Expr.pos e_case] "The constructor %a has been matched twice:" EnumConstructor.format c_uid); - let ctxt, param_var = create_var (Option.map Mark.remove binding) in + let local_vars, param_var = create_var local_vars (Option.map Mark.remove binding) in let case_body = - translate_expr scope inside_definition_of ctxt - case.Surface.Ast.match_case_expr + translate_expr scope inside_definition_of ctxt local_vars + case.S.match_case_expr in let e_binder = Expr.bind [| param_var |] case_body in let case_expr = bind_case_body c_uid e_uid ctxt case_body e_binder in ( EnumConstructor.Map.add c_uid case_expr cases_d, Some e_uid, curr_index + 1 ) - | Surface.Ast.WildCard match_case_expr -> ( + | S.WildCard match_case_expr -> ( let nb_cases = List.length cases in let raise_wildcard_not_last_case_err () = Message.raise_multispanned_error @@ -867,9 +892,9 @@ and disambiguate_match_and_build_expression ... | CaseN -> wildcard_payload *) (* Creates the wildcard payload *) - let ctxt, payload_var = create_var None in + let local_vars, payload_var = create_var local_vars None in let case_body = - translate_expr scope inside_definition_of ctxt match_case_expr + translate_expr scope inside_definition_of ctxt local_vars match_case_expr in let e_binder = Expr.bind [| payload_var |] case_body in @@ -941,13 +966,13 @@ let rec arglist_eq_check pos_decl pos_def pdecl pdefs = let process_rule_parameters ctxt (def_key : Ast.ScopeDef.t Mark.pos) - (def : Surface.Ast.definition) : - Name_resolution.context + (def : S.definition) : + Ast.expr Var.t Ident.Map.t * (Ast.expr Var.t Mark.pos * typ) list Mark.pos option = let decl_name, decl_pos = def_key in let declared_params = Name_resolution.get_params ctxt decl_name in match declared_params, def.S.definition_parameter with - | None, None -> ctxt, None + | None, None -> Ident.Map.empty, None | None, Some (_, pos) -> Message.raise_multispanned_error [ @@ -960,25 +985,27 @@ let process_rule_parameters [ Some "Arguments declared here", pos; ( Some "Definition missing the arguments", - Mark.get def.Surface.Ast.definition_name ); + Mark.get def.S.definition_name ); ] "This definition for %a is missing the arguments" Ast.ScopeDef.format decl_name | Some (pdecl, pos_decl), Some (pdefs, pos_def) -> arglist_eq_check pos_decl pos_def (List.map fst pdecl) pdefs; - let ctxt, params = + let local_vars, params = List.fold_left_map - (fun ctxt ((lbl, pos), ty) -> - let ctxt, v = Name_resolution.add_def_local_var ctxt lbl in - ctxt, ((v, pos), ty)) - ctxt pdecl + (fun local_vars ((lbl, pos), ty) -> + let v = Var.make lbl in + let local_vars = Ident.Map.add lbl v local_vars in + local_vars, ((v, pos), ty)) + Ident.Map.empty pdecl in - ctxt, Some (params, pos_def) + local_vars, Some (params, pos_def) (** Translates a surface definition into condition into a desugared {!type: Ast.rule} *) let process_default (ctxt : Name_resolution.context) + (local_vars : Ast.expr Var.t Ident.Map.t) (scope : ScopeName.t) (def_key : Ast.ScopeDef.t Mark.pos) (rule_id : RuleName.t) @@ -986,15 +1013,15 @@ let process_default (precond : Ast.expr boxed option) (exception_situation : Ast.exception_situation) (label_situation : Ast.label_situation) - (just : Surface.Ast.expression option) - (cons : Surface.Ast.expression) : Ast.rule = + (just : S.expression option) + (cons : S.expression) : Ast.rule = let just = match just with - | Some just -> Some (translate_expr (Some scope) (Some def_key) ctxt just) + | Some just -> Some (translate_expr (Some scope) (Some def_key) ctxt local_vars just) | None -> None in let just = merge_conditions precond just (Mark.get def_key) in - let cons = translate_expr (Some scope) (Some def_key) ctxt cons in + let cons = translate_expr (Some scope) (Some def_key) ctxt local_vars cons in { Ast.rule_just = just; rule_cons = cons; @@ -1011,7 +1038,7 @@ let process_def (scope_uid : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Ast.program) - (def : Surface.Ast.definition) : Ast.program = + (def : S.definition) : Ast.program = let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in let scope_ctxt = ScopeName.Map.find scope_uid ctxt.scopes in let def_key = @@ -1024,7 +1051,8 @@ let process_def Ast.ScopeDef.Map.find def_key scope_ctxt.scope_defs_contexts in (* We add to the name resolution context the name of the parameter variable *) - let new_ctxt, param_uids = + Message.emit_debug "PROCESS_DEF %a@!" Ast.ScopeDef.format def_key; + let local_vars, param_uids = process_rule_parameters ctxt (Mark.copy def.definition_name def_key) def in let scope_updated = @@ -1038,7 +1066,7 @@ let process_def | None -> Ast.Unlabeled in let exception_situation = - match def.Surface.Ast.definition_exception_to with + match def.S.definition_exception_to with | NotAnException -> Ast.BaseCase | UnlabeledException -> ( match scope_def_ctxt.default_exception_rulename with @@ -1064,7 +1092,7 @@ let process_def scope_def with scope_def_rules = RuleName.Map.add rule_name - (process_default new_ctxt scope_uid + (process_default ctxt local_vars scope_uid (def_key, Mark.get def.definition_name) rule_name param_uids precond exception_situation label_situation def.definition_condition def.definition_expr) @@ -1082,14 +1110,14 @@ let process_def ScopeName.Map.add scope_uid scope_updated prgm.program_scopes; } -(** Translates a {!type: Surface.Ast.rule} from the surface language *) +(** Translates a {!type: S.rule} from the surface language *) let process_rule (precond : Ast.expr boxed option) (scope : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Ast.program) - (rule : Surface.Ast.rule) : Ast.program = - let def = Surface.Ast.rule_to_def rule in + (rule : S.rule) : Ast.program = + let def = S.rule_to_def rule in process_def precond scope ctxt prgm def (** Translates assertions *) @@ -1098,17 +1126,17 @@ let process_assert (scope_uid : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Ast.program) - (ass : Surface.Ast.assertion) : Ast.program = + (ass : S.assertion) : Ast.program = let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in let ass = - translate_expr (Some scope_uid) None ctxt - (match ass.Surface.Ast.assertion_condition with - | None -> ass.Surface.Ast.assertion_content + translate_expr (Some scope_uid) None ctxt Ident.Map.empty + (match ass.S.assertion_condition with + | None -> ass.S.assertion_content | Some cond -> - ( Surface.Ast.IfThenElse + ( S.IfThenElse ( cond, - ass.Surface.Ast.assertion_content, - Mark.copy cond (Surface.Ast.Literal (Surface.Ast.LBool true)) ), + ass.S.assertion_content, + Mark.copy cond (S.Literal (S.LBool true)) ), Mark.get cond )) in let assertion = @@ -1138,23 +1166,23 @@ let process_assert (** Translates a surface definition, rule or assertion *) let process_scope_use_item - (precond : Surface.Ast.expression option) + (precond : S.expression option) (scope : ScopeName.t) (ctxt : Name_resolution.context) (prgm : Ast.program) - (item : Surface.Ast.scope_use_item Mark.pos) : Ast.program = - let precond = Option.map (translate_expr (Some scope) None ctxt) precond in + (item : S.scope_use_item Mark.pos) : Ast.program = + let precond = Option.map (translate_expr (Some scope) None ctxt Ident.Map.empty) precond in match Mark.remove item with - | Surface.Ast.Rule rule -> process_rule precond scope ctxt prgm rule - | Surface.Ast.Definition def -> process_def precond scope ctxt prgm def - | Surface.Ast.Assertion ass -> process_assert precond scope ctxt prgm ass - | Surface.Ast.DateRounding (r, _) -> + | S.Rule rule -> process_rule precond scope ctxt prgm rule + | S.Definition def -> process_def precond scope ctxt prgm def + | S.Assertion ass -> process_assert precond scope ctxt prgm ass + | S.DateRounding (r, _) -> let scope_uid = scope in let scope : Ast.scope = ScopeName.Map.find scope_uid prgm.program_scopes in let r = match r with - | Surface.Ast.Increasing -> Ast.Increasing - | Surface.Ast.Decreasing -> Ast.Decreasing + | S.Increasing -> Ast.Increasing + | S.Decreasing -> Ast.Decreasing in let new_scope = match @@ -1188,18 +1216,18 @@ let process_scope_use_item let check_unlabeled_exception (scope : ScopeName.t) (ctxt : Name_resolution.context) - (item : Surface.Ast.scope_use_item Mark.pos) : unit = + (item : S.scope_use_item Mark.pos) : unit = let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in match Mark.remove item with - | Surface.Ast.Rule _ | Surface.Ast.Definition _ -> ( + | S.Rule _ | S.Definition _ -> ( let def_key, exception_to = match Mark.remove item with - | Surface.Ast.Rule rule -> + | S.Rule rule -> ( Name_resolution.get_def_key (Mark.remove rule.rule_name) rule.rule_state scope ctxt (Mark.get rule.rule_name), rule.rule_exception_to ) - | Surface.Ast.Definition def -> + | S.Definition def -> ( Name_resolution.get_def_key (Mark.remove def.definition_name) def.definition_state scope ctxt @@ -1212,10 +1240,10 @@ let check_unlabeled_exception Ast.ScopeDef.Map.find def_key scope_ctxt.scope_defs_contexts in match exception_to with - | Surface.Ast.NotAnException | Surface.Ast.ExceptionToLabel _ -> () + | S.NotAnException | S.ExceptionToLabel _ -> () (* If this is an unlabeled exception, we check that it has a unique default definition *) - | Surface.Ast.UnlabeledException -> ( + | S.UnlabeledException -> ( match scope_def_ctxt.default_exception_rulename with | None -> Message.raise_spanned_error (Mark.get item) @@ -1233,7 +1261,7 @@ let check_unlabeled_exception let process_scope_use (ctxt : Name_resolution.context) (prgm : Ast.program) - (use : Surface.Ast.scope_use) : Ast.program = + (use : S.scope_use) : Ast.program = let scope_uid = Name_resolution.get_scope ctxt use.scope_use_name in (* Make sure the scope exists *) let prgm = @@ -1261,16 +1289,17 @@ let process_topdef let expr_opt = match def.S.topdef_expr, def.S.topdef_args with | None, _ -> None - | Some e, None -> Some (Expr.unbox_closed (translate_expr None None ctxt e)) + | Some e, None -> Some (Expr.unbox_closed (translate_expr None None ctxt Ident.Map.empty e)) | Some e, Some (args, _) -> - let ctxt, args_tys = + let local_vars, args_tys = List.fold_left_map - (fun ctxt ((lbl, pos), ty) -> - let ctxt, v = Name_resolution.add_def_local_var ctxt lbl in - ctxt, ((v, pos), ty)) - ctxt args + (fun local_vars ((lbl, pos), ty) -> + let v = Var.make lbl in + let local_vars = Ident.Map.add lbl v local_vars in + local_vars, ((v, pos), ty)) + Ident.Map.empty args in - let body = translate_expr None None ctxt e in + let body = translate_expr None None ctxt local_vars e in let args, tys = List.split args_tys in let e = Expr.make_abs @@ -1303,16 +1332,16 @@ let process_topdef in { prgm with Ast.program_topdefs } -let attribute_to_io (attr : Surface.Ast.scope_decl_context_io) : Ast.io = +let attribute_to_io (attr : S.scope_decl_context_io) : Ast.io = { Ast.io_output = attr.scope_decl_context_io_output; Ast.io_input = Mark.map (fun io -> match io with - | Surface.Ast.Input -> Runtime.OnlyInput - | Surface.Ast.Internal -> Runtime.NoInput - | Surface.Ast.Context -> Runtime.Reentrant) + | S.Input -> Runtime.OnlyInput + | S.Internal -> Runtime.NoInput + | S.Context -> Runtime.Reentrant) attr.scope_decl_context_io_input; } @@ -1370,7 +1399,8 @@ let init_scope_defs (scope_def_map, 0) states in scope_def) - | Name_resolution.SubScope (v0, subscope_uid) -> + | Name_resolution.SubScope (v0, (path, subscope_uid)) -> + let ctxt = Name_resolution.module_ctx ctxt path in let sub_scope_def = ScopeName.Map.find subscope_uid ctxt.Name_resolution.scopes in @@ -1401,9 +1431,9 @@ let init_scope_defs (** Main function of this module *) let translate_program (ctxt : Name_resolution.context) - (prgm : Surface.Ast.program) : Ast.program = - let empty_prgm = - let program_scopes = + (surface : S.program) : Ast.program = + let desugared = + let get_program_scopes ctxt = ScopeName.Map.mapi (fun s_uid s_context -> let scope_vars = @@ -1412,8 +1442,8 @@ let translate_program match v with | Name_resolution.SubScope _ -> acc | Name_resolution.ScopeVar v -> ( - let v_sig = ScopeVar.Map.find v ctxt.var_typs in - match v_sig.var_sig_states_list with + let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in + match v_sig.Name_resolution.var_sig_states_list with | [] -> ScopeVar.Map.add v Ast.WholeVar acc | states -> ScopeVar.Map.add v (Ast.States states) acc)) s_context.Name_resolution.var_idmap ScopeVar.Map.empty @@ -1438,57 +1468,90 @@ let translate_program }) ctxt.Name_resolution.scopes in - let translate_type t = Name_resolution.process_type ctxt t in - { - Ast.program_ctx = - { - ctx_structs = ctxt.Name_resolution.structs; - ctx_enums = ctxt.Name_resolution.enums; - ctx_scopes = - Ident.Map.fold - (fun _ def acc -> - match def with - | Name_resolution.TScope (scope, scope_out_struct) -> - ScopeName.Map.add scope scope_out_struct acc - | _ -> acc) - ctxt.Name_resolution.typedefs ScopeName.Map.empty; - ctx_struct_fields = ctxt.Name_resolution.field_idmap; - ctx_modules = - List.fold_left - (fun map (path, def) -> - match def with - | Surface.Ast.Topdef { topdef_name; topdef_type; _ }, _pos -> - Qident.Map.add - (path, Mark.remove topdef_name) - (translate_type topdef_type) - map - | (ScopeDecl _ | StructDecl _ | EnumDecl _), _ (* as e *) -> - map (* assert false (\* TODO *\) *) - | ScopeUse _, _ -> assert false) - Qident.Map.empty prgm.Surface.Ast.program_interfaces; - }; - Ast.program_topdefs = TopdefName.Map.empty; - Ast.program_scopes; - } + let rec make_ctx ctxt = + let submodules = + ModuleName.Map.map make_ctx ctxt.Name_resolution.modules; + in + { + Ast.program_ctx = + { + (* After name resolution, type definitions (structs and enums) are exposed at toplevel for easier lookup, but their paths need to remain available for printing and later passes *) + ctx_structs = + ModuleName.Map.fold (fun modname prg acc -> + StructName.Map.union (fun _ _ _ -> assert false) acc + (StructName.Map.map + (fun (path, def) -> (modname, Pos.no_pos) :: path, def) + prg.Ast.program_ctx.ctx_structs)) + submodules + (StructName.Map.map (fun def -> [], def) ctxt.Name_resolution.structs); + ctx_enums = + ModuleName.Map.fold (fun modname prg acc -> + EnumName.Map.union (fun _ _ _ -> assert false) acc + (EnumName.Map.map + (fun (path, def) -> (modname, Pos.no_pos) :: path, def) + prg.Ast.program_ctx.ctx_enums)) + submodules + (EnumName.Map.map (fun def -> [], def) ctxt.Name_resolution.enums); + ctx_scopes = + Ident.Map.fold + (fun _ def acc -> + match def with + | Name_resolution.TScope (scope, scope_info) -> + ScopeName.Map.add scope scope_info acc + | _ -> acc) + ctxt.Name_resolution.typedefs ScopeName.Map.empty; + ctx_struct_fields = ctxt.Name_resolution.field_idmap; + ctx_topdefs = ctxt.Name_resolution.topdef_types; + ctx_modules = ModuleName.Map.map (fun s -> s.Ast.program_ctx) submodules; + }; + Ast.program_topdefs = TopdefName.Map.empty; + Ast.program_scopes = get_program_scopes ctxt; + Ast.program_modules = submodules; + } + in + make_ctx ctxt in - let rec processer_structure + let process_code_block ctxt prgm block = + List.fold_left + (fun prgm item -> + match Mark.remove item with + | S.ScopeUse use -> process_scope_use ctxt prgm use + | S.Topdef def -> process_topdef ctxt prgm def + | S.ScopeDecl _ | S.StructDecl _ + | S.EnumDecl _ -> + prgm) + prgm block + in + let rec process_structure (prgm : Ast.program) - (item : Surface.Ast.law_structure) : Ast.program = + (item : S.law_structure) : Ast.program = match item with - | LawHeading (_, children) -> + | S.LawHeading (_, children) -> List.fold_left - (fun prgm child -> processer_structure prgm child) + (fun prgm child -> process_structure prgm child) prgm children - | CodeBlock (block, _, _) -> - List.fold_left - (fun prgm item -> - match Mark.remove item with - | Surface.Ast.ScopeUse use -> process_scope_use ctxt prgm use - | Surface.Ast.Topdef def -> process_topdef ctxt prgm def - | Surface.Ast.ScopeDecl _ | Surface.Ast.StructDecl _ - | Surface.Ast.EnumDecl _ -> - prgm) - prgm block - | LawInclude _ | LawText _ -> prgm + | S.CodeBlock (block, _, _) -> + process_code_block ctxt prgm block + | S.LawInclude _ | S.LawText _ -> prgm in - List.fold_left processer_structure empty_prgm prgm.program_items + Message.emit_debug "DESUGARED → prog scopes: %a@ modules: %a" + (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) desugared.Ast.program_scopes + (ModuleName.Map.format + (fun fmt prg -> ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt prg.Ast.program_scopes)) desugared.Ast.program_modules; + let desugared = + List.fold_left (fun acc (id, intf) -> + let modul = ModuleName.Map.find id acc.Ast.program_modules in + let modul = process_code_block (Name_resolution.module_ctx ctxt [id, Pos.no_pos]) modul intf in + { acc with program_modules = + ModuleName.Map.add id modul acc.program_modules }) + desugared + surface.S.program_modules + in + let desugared = + List.fold_left process_structure desugared surface.S.program_items + in + Message.emit_debug "DESUGARED2 → prog scopes: %a@ modules: %a" + (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) desugared.Ast.program_scopes + (ModuleName.Map.format + (fun fmt prg -> ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt prg.Ast.program_scopes)) desugared.Ast.program_modules; + desugared diff --git a/compiler/desugared/linting.ml b/compiler/desugared/linting.ml index feb5ba7b..68e1d03f 100644 --- a/compiler/desugared/linting.ml +++ b/compiler/desugared/linting.ml @@ -108,7 +108,7 @@ let detect_unused_struct_fields (p : program) : unit = ~f:(fun struct_fields_used e -> let rec structs_fields_used_expr e struct_fields_used = match Mark.remove e with - | EDStructAccess { name_opt = Some name; e = e_struct; field } -> + | EDStructAccess { name_opt = Some name; e = e_struct; field; path = _ } -> let field = StructName.Map.find name (Ident.Map.find field p.program_ctx.ctx_struct_fields) @@ -135,8 +135,9 @@ let detect_unused_struct_fields (p : program) : unit = p.program_ctx.ctx_scopes StructField.Set.empty in StructName.Map.iter - (fun s_name fields -> - if + (fun s_name (path, fields) -> + if path <> [] then () + else if (not (StructField.Map.is_empty fields)) && StructField.Map.for_all (fun field _ -> @@ -190,8 +191,9 @@ let detect_unused_enum_constructors (p : program) : unit = ~init:EnumConstructor.Set.empty p in EnumName.Map.iter - (fun e_name constructors -> - if + (fun e_name (path, constructors) -> + if path <> [] then () + else if EnumConstructor.Map.for_all (fun cons _ -> not (EnumConstructor.Set.mem cons enum_constructors_used)) diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index ccee3c34..4684e853 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -32,7 +32,7 @@ type scope_def_context = { type scope_var_or_subscope = | ScopeVar of ScopeVar.t - | SubScope of SubScopeName.t * ScopeName.t + | SubScope of SubScopeName.t * (path * ScopeName.t) type scope_context = { var_idmap : scope_var_or_subscope Ident.Map.t; @@ -65,13 +65,10 @@ type var_sig = { type typedef = | TStruct of StructName.t | TEnum of EnumName.t - | TScope of ScopeName.t * scope_out_struct + | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) type context = { - local_var_idmap : Ast.expr Var.t Ident.Map.t; - (** Inside a definition, local variables can be introduced by functions - arguments or pattern matching *) typedefs : typedef Ident.Map.t; (** Gathers the names of the scopes, structs and enums *) field_idmap : StructField.t StructName.Map.t Ident.Map.t; @@ -82,11 +79,13 @@ type context = { between different enums *) scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) + topdef_types : typ TopdefName.Map.t; structs : struct_context StructName.Map.t; (** For each struct, its context *) enums : enum_context EnumName.Map.t; (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) + modules : context ModuleName.Map.t; } (** Main context used throughout {!module: Surface.Desugaring} *) @@ -213,7 +212,7 @@ let get_struct ctxt id = None, Mark.get id; Some "Enum defined at", Mark.get (EnumName.get_info eid); ] - "Expecting an struct, but found an enum" + "Expecting a struct, but found an enum" | exception Not_found -> Message.raise_spanned_error (Mark.get id) "No struct named %s found" (Mark.remove id) @@ -239,6 +238,16 @@ let get_scope ctxt id = Message.raise_spanned_error (Mark.get id) "No scope named %s found" (Mark.remove id) +let rec module_ctx ctxt path = match path with + | [] -> ctxt + | (modname, mpos) :: path -> + (match ModuleName.Map.find_opt modname ctxt.modules with + | None -> + Message.raise_spanned_error mpos + "Module %a not found" ModuleName.format modname + | Some ctxt -> + module_ctx ctxt path) + (** {1 Declarations pass} *) (** Process a subscope declaration *) @@ -247,9 +256,9 @@ let process_subscope_decl (ctxt : context) (decl : Surface.Ast.scope_decl_context_scope) : context = let name, name_pos = decl.scope_decl_context_scope_name in - let subscope, s_pos = decl.scope_decl_context_scope_sub_scope in + let (path, subscope), s_pos = decl.scope_decl_context_scope_sub_scope in let scope_ctxt = ScopeName.Map.find scope ctxt.scopes in - match Ident.Map.find_opt subscope scope_ctxt.var_idmap with + match Ident.Map.find_opt (Mark.remove subscope) scope_ctxt.var_idmap with | Some use -> let info = match use with @@ -258,18 +267,20 @@ let process_subscope_decl in Message.raise_multispanned_error [Some "first use", Mark.get info; Some "second use", s_pos] - "Subscope name @{\"%s\"@} already used" subscope + "Subscope name @{\"%s\"@} already used" + (Mark.remove subscope) | None -> let sub_scope_uid = SubScopeName.fresh (name, name_pos) in let original_subscope_uid = - get_scope ctxt decl.scope_decl_context_scope_sub_scope + let ctxt = module_ctx ctxt path in + get_scope ctxt subscope in let scope_ctxt = { scope_ctxt with var_idmap = Ident.Map.add name - (SubScope (sub_scope_uid, original_subscope_uid)) + (SubScope (sub_scope_uid, (path, original_subscope_uid))) scope_ctxt.var_idmap; sub_scopes = ScopeName.Set.add original_subscope_uid scope_ctxt.sub_scopes; @@ -305,18 +316,23 @@ let rec process_base_typ | Surface.Ast.Text -> raise_unsupported_feature "text type" typ_pos | Surface.Ast.Named ([], (ident, _pos)) -> ( match Ident.Map.find_opt ident ctxt.typedefs with - | Some (TStruct s_uid) -> TStruct s_uid, typ_pos - | Some (TEnum e_uid) -> TEnum e_uid, typ_pos + | Some (TStruct s_uid) -> TStruct ( s_uid), typ_pos + | Some (TEnum e_uid) -> TEnum ( e_uid), typ_pos | Some (TScope (_, scope_str)) -> - TStruct scope_str.out_struct_name, typ_pos + TStruct ( scope_str.out_struct_name), typ_pos | None -> Message.raise_spanned_error typ_pos "Unknown type @{\"%s\"@}, not a struct or enum previously \ declared" ident) - | Surface.Ast.Named (_path, (_ident, _pos)) -> - Message.raise_spanned_error typ_pos - "Qualified paths are not supported yet") + | Surface.Ast.Named ((modul, mpos)::path, id) -> + match ModuleName.Map.find_opt modul ctxt.modules with + | None -> + Message.raise_spanned_error mpos + "This refers to module %a, which was not found" + ModuleName.format modul + | Some mod_ctxt -> + process_base_typ mod_ctxt Surface.Ast.(Data (Primitive (Named (path, id))), typ_pos)) (** Process a type (function or not) *) let process_type (ctxt : context) ((naked_typ, typ_pos) : Surface.Ast.typ) : typ @@ -405,18 +421,6 @@ let process_data_decl ctxt.var_typs; } -(** Adds a binding to the context *) -let add_def_local_var (ctxt : context) (name : Ident.t) : - context * Ast.expr Var.t = - let local_var_uid = Var.make name in - let ctxt = - { - ctxt with - local_var_idmap = Ident.Map.add name local_var_uid ctxt.local_var_idmap; - } - in - ctxt, local_var_uid - (** Process a struct declaration *) let process_struct_decl (ctxt : context) (sdecl : Surface.Ast.struct_decl) : context = @@ -584,8 +588,8 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : Ident.Map.update (Mark.remove decl.scope_decl_name) (function - | Some (TScope (scope, { out_struct_name; _ })) -> - Some (TScope (scope, { out_struct_name; out_struct_fields })) + | Some (TScope (scope, { in_struct_name; out_struct_name; _ })) -> + Some (TScope (scope, { in_struct_name; out_struct_name; out_struct_fields; })) | _ -> assert false) ctxt.typedefs in @@ -617,7 +621,8 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : raise_already_defined_error (typedef_info use) name pos "scope") (Ident.Map.find_opt name ctxt.typedefs); let scope_uid = ScopeName.fresh (name, pos) in - let out_struct_uid = StructName.fresh (name, pos) in + let in_struct_name = StructName.fresh (name ^ "_in", pos) in + let out_struct_name = StructName.fresh (name, pos) in { ctxt with typedefs = @@ -625,7 +630,8 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (TScope ( scope_uid, { - out_struct_name = out_struct_uid; + in_struct_name; + out_struct_name; out_struct_fields = ScopeVar.Map.empty; } )) ctxt.typedefs; @@ -675,7 +681,9 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : "toplevel definition") (Ident.Map.find_opt name ctxt.topdefs); let uid = TopdefName.fresh def.topdef_name in - { ctxt with topdefs = Ident.Map.add name uid ctxt.topdefs } + { ctxt with + topdefs = Ident.Map.add name uid ctxt.topdefs; + topdef_types = TopdefName.Map.add uid (process_type ctxt def.topdef_type) ctxt.topdef_types } (** Process a code item that is a declaration *) let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : @@ -689,25 +697,25 @@ let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (** Process a code block *) let process_code_block + (process_item : context -> Surface.Ast.code_item Mark.pos -> context) (ctxt : context) - (block : Surface.Ast.code_block) - (process_item : context -> Surface.Ast.code_item Mark.pos -> context) : + (block : Surface.Ast.code_block) : context = List.fold_left (fun ctxt decl -> process_item ctxt decl) ctxt block (** Process a law structure, only considering the code blocks *) let rec process_law_structure + (process_item : context -> Surface.Ast.code_item Mark.pos -> context) (ctxt : context) - (s : Surface.Ast.law_structure) - (process_item : context -> Surface.Ast.code_item Mark.pos -> context) : + (s : Surface.Ast.law_structure) : context = match s with | Surface.Ast.LawHeading (_, children) -> List.fold_left - (fun ctxt child -> process_law_structure ctxt child process_item) + (fun ctxt child -> process_law_structure process_item ctxt child) ctxt children | Surface.Ast.CodeBlock (block, _, _) -> - process_code_block ctxt block process_item + process_code_block process_item ctxt block | Surface.Ast.LawInclude _ | Surface.Ast.LawText _ -> ctxt (** {1 Scope uses pass} *) @@ -750,7 +758,7 @@ let get_def_key ScopeVar.format x_uid else None ) | [y; x] -> - let (subscope_uid, subscope_real_uid) : SubScopeName.t * ScopeName.t = + let (subscope_uid, (path, subscope_real_uid)) : SubScopeName.t * (path * ScopeName.t) = match Ident.Map.find_opt (Mark.remove y) scope_ctxt.var_idmap with | Some (SubScope (v, u)) -> v, u | Some _ -> @@ -761,7 +769,10 @@ let get_def_key Message.raise_spanned_error pos "No definition found for subscope %a" Print.lit_style (Mark.remove y) in - let x_uid = get_var_uid subscope_real_uid ctxt x in + let x_uid = + let ctxt = module_ctx ctxt path in + get_var_uid subscope_real_uid ctxt x + in Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid, pos) | _ -> Message.raise_spanned_error pos @@ -906,34 +917,51 @@ let process_use_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : (** {1 API} *) +let empty_ctxt = + { + typedefs = Ident.Map.empty; + scopes = ScopeName.Map.empty; + topdefs = Ident.Map.empty; + topdef_types = TopdefName.Map.empty; + var_typs = ScopeVar.Map.empty; + structs = StructName.Map.empty; + field_idmap = Ident.Map.empty; + enums = EnumName.Map.empty; + constructor_idmap = Ident.Map.empty; + modules = ModuleName.Map.empty; + } + +let import_module modules (name, intf) = + let ctxt = { empty_ctxt with modules } in + let ctxt = + List.fold_left process_name_item ctxt intf + in + let ctxt = + List.fold_left process_decl_item ctxt intf + in + let ctxt = { ctxt with modules = empty_ctxt.modules } in + (* No submodules at the moment, a module may use the ones loaded before it, but doesn't reexport them *) + ModuleName.Map.add name ctxt modules + (** Derive the context from metadata, in one pass over the declarations *) let form_context (prgm : Surface.Ast.program) : context = - let empty_ctxt = - { - local_var_idmap = Ident.Map.empty; - typedefs = Ident.Map.empty; - scopes = ScopeName.Map.empty; - topdefs = Ident.Map.empty; - var_typs = ScopeVar.Map.empty; - structs = StructName.Map.empty; - field_idmap = Ident.Map.empty; - enums = EnumName.Map.empty; - constructor_idmap = Ident.Map.empty; - } + let modules = + List.fold_left import_module ModuleName.Map.empty prgm.program_modules in + let ctxt = { empty_ctxt with modules } in let ctxt = List.fold_left - (fun ctxt item -> process_law_structure ctxt item process_name_item) - empty_ctxt prgm.program_items - in - let ctxt = - List.fold_left - (fun ctxt item -> process_law_structure ctxt item process_decl_item) + (process_law_structure process_name_item) ctxt prgm.program_items in let ctxt = List.fold_left - (fun ctxt item -> process_law_structure ctxt item process_use_item) + (process_law_structure process_decl_item) + ctxt prgm.program_items + in + let ctxt = + List.fold_left + (process_law_structure process_use_item) ctxt prgm.program_items in ctxt diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index bfb011b5..f96b069c 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -32,7 +32,7 @@ type scope_def_context = { type scope_var_or_subscope = | ScopeVar of ScopeVar.t - | SubScope of SubScopeName.t * ScopeName.t + | SubScope of SubScopeName.t * (path * ScopeName.t) type scope_context = { var_idmap : scope_var_or_subscope Ident.Map.t; @@ -65,13 +65,10 @@ type var_sig = { type typedef = | TStruct of StructName.t | TEnum of EnumName.t - | TScope of ScopeName.t * scope_out_struct + | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) type context = { - local_var_idmap : Ast.expr Var.t Ident.Map.t; - (** Inside a definition, local variables can be introduced by functions - arguments or pattern matching *) typedefs : typedef Ident.Map.t; (** Gathers the names of the scopes, structs and enums *) field_idmap : StructField.t StructName.Map.t Ident.Map.t; @@ -82,11 +79,13 @@ type context = { between different enums *) scopes : scope_context ScopeName.Map.t; (** For each scope, its context *) topdefs : TopdefName.t Ident.Map.t; (** Global definitions *) + topdef_types : typ TopdefName.Map.t; structs : struct_context StructName.Map.t; (** For each struct, its context *) enums : enum_context EnumName.Map.t; (** For each enum, its context *) var_typs : var_sig ScopeVar.Map.t; (** The signatures of each scope variable declared *) + modules : context ModuleName.Map.t; } (** Main context used throughout {!module: Desugared.From_surface} *) @@ -131,9 +130,6 @@ val get_params : val is_def_cond : context -> Ast.ScopeDef.t -> bool val is_type_cond : Surface.Ast.typ -> bool -val add_def_local_var : context -> Ident.t -> context * Ast.expr Var.t -(** Adds a binding to the context *) - val get_def_key : Surface.Ast.scope_var -> Surface.Ast.lident Mark.pos option -> @@ -155,6 +151,9 @@ val get_scope : context -> Ident.t Mark.pos -> ScopeName.t (** Find a scope definition from the typedefs, failing if there is none or it has a different kind *) +val module_ctx : context -> path -> context +(** Returns the context corresponding to the given module path; raises a user error if the module is not found *) + val process_type : context -> Surface.Ast.typ -> typ (** Convert a surface base type to an AST type *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 311e00f1..38ab88d8 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -42,30 +42,36 @@ let get_lang options file = @{%s@}, and @{--language@} was not specified" filename) -let load_module_interfaces prg options link_modules = - List.fold_left - (fun prg f -> +let load_module_interfaces options link_modules = + List.map + (fun f -> let lang = get_lang options (FileName f) in let modname = modname_of_file f in - Surface.Parser_driver.add_interface (FileName f) lang [modname] prg) - prg link_modules + let intf = Surface.Parser_driver.load_interface (FileName f) lang in + modname, intf) + link_modules module Passes = struct (* Each pass takes only its cli options, then calls upon its dependent passes (forwarding their options as needed) *) - let surface options : Surface.Ast.program * Cli.backend_lang = + let surface options ~link_modules : Surface.Ast.program * Cli.backend_lang = Message.emit_debug "Reading files..."; let language = get_lang options options.input_file in let prg = Surface.Parser_driver.parse_top_level_file options.input_file language in - Surface.Fill_positions.fill_pos_with_legislative_info prg, language + let prg = + Surface.Fill_positions.fill_pos_with_legislative_info prg + in + let prg = + { prg with program_modules = load_module_interfaces options link_modules } + in + prg, language let desugared options ~link_modules : Desugared.Ast.program * Desugared.Name_resolution.context = - let prg, _ = surface options in - let prg = load_module_interfaces prg options link_modules in + let prg, _ = surface options ~link_modules in Message.emit_debug "Name resolution..."; let ctx = Desugared.Name_resolution.form_context prg in (* let scope_uid = get_scope_uid options backend ctx in @@ -250,7 +256,7 @@ module Commands = struct "Variable @{\"%s\"@} not found inside scope @{\"%a\"@}" variable ScopeName.format scope_uid | Some - (Desugared.Name_resolution.SubScope (subscope_var_name, subscope_name)) + (Desugared.Name_resolution.SubScope (subscope_var_name, (subscope_path, subscope_name))) -> ( match second_part with | None -> @@ -261,6 +267,7 @@ module Commands = struct SubScopeName.format subscope_var_name ScopeName.format scope_uid | Some second_part -> ( match + let ctxt = Desugared.Name_resolution.module_ctx ctxt subscope_path in Ident.Map.find_opt second_part (ScopeName.Map.find subscope_name ctxt.scopes).var_idmap with @@ -299,7 +306,7 @@ module Commands = struct ~output_file ?ext () let makefile options output = - let prg, _ = Passes.surface options in + let prg, _ = Passes.surface options ~link_modules:[] in let backend_extensions_list = [".tex"] in let source_file = match options.Cli.input_file with @@ -330,7 +337,7 @@ module Commands = struct Term.(const makefile $ Cli.Flags.Global.options $ Cli.Flags.output) let html options output print_only_law wrap_weaved_output = - let prg, language = Passes.surface options in + let prg, language = Passes.surface options ~link_modules:[] in Message.emit_debug "Weaving literate program into HTML"; let output_file, with_output = get_output_format options ~ext:".html" output @@ -358,7 +365,7 @@ module Commands = struct $ Cli.Flags.wrap_weaved_output) let latex options output print_only_law wrap_weaved_output = - let prg, language = Passes.surface options in + let prg, language = Passes.surface options ~link_modules:[] in Message.emit_debug "Weaving literate program into LaTeX"; let output_file, with_output = get_output_format options ~ext:".tex" output diff --git a/compiler/driver.mli b/compiler/driver.mli index b5e9c2b1..a66134cf 100644 --- a/compiler/driver.mli +++ b/compiler/driver.mli @@ -25,7 +25,10 @@ val main : unit -> unit Each pass takes only its cli options, then calls upon its dependent passes (forwarding their options as needed) *) module Passes : sig - val surface : Cli.options -> Surface.Ast.program * Cli.backend_lang + val surface : + Cli.options -> + link_modules:string list -> + Surface.Ast.program * Cli.backend_lang val desugared : Cli.options -> diff --git a/compiler/lcalc/ast.ml b/compiler/lcalc/ast.ml index 98c1e90b..a6c07258 100644 --- a/compiler/lcalc/ast.ml +++ b/compiler/lcalc/ast.ml @@ -23,10 +23,10 @@ type 'm program = 'm expr Shared_ast.program module OptionMonad = struct let return ~(mark : 'a mark) e = - Expr.einj e Expr.some_constr Expr.option_enum mark + Expr.einj ~e ~cons:Expr.some_constr ~name:Expr.option_enum mark let empty ~(mark : 'a mark) = - Expr.einj (Expr.elit LUnit mark) Expr.none_constr Expr.option_enum mark + Expr.einj ~e:(Expr.elit LUnit mark) ~cons:Expr.none_constr ~name:Expr.option_enum mark let bind_var ~(mark : 'a mark) f x arg = let cases = @@ -36,7 +36,7 @@ module OptionMonad = struct let x = Var.make "_" in Expr.eabs (Expr.bind [| x |] - (Expr.einj (Expr.evar x mark) Expr.none_constr Expr.option_enum + (Expr.einj ~e:(Expr.evar x mark) ~cons:Expr.none_constr ~name:Expr.option_enum mark)) [TLit TUnit, Expr.mark_pos mark] mark ); @@ -46,7 +46,7 @@ module OptionMonad = struct (*| Some x -> f (where f contains x as a free variable) *); ] in - Expr.ematch arg Expr.option_enum cases mark + Expr.ematch ~e:arg ~name:Expr.option_enum ~cases mark let bind ~(mark : 'a mark) ~(var_name : string) f arg = let x = Var.make var_name in @@ -86,8 +86,8 @@ module OptionMonad = struct ListLabels.fold_left2 xs args ~f:(bind_var ~mark) ~init: (Expr.einj - (Expr.eapp f (List.map (fun v -> Expr.evar v mark) xs) mark) - Expr.some_constr Expr.option_enum mark) + ~e:(Expr.eapp f (List.map (fun v -> Expr.evar v mark) xs) mark) + ~cons:Expr.some_constr ~name:Expr.option_enum mark) let map_var ~(mark : 'a mark) f x arg = mmap_mvar f [x] [arg] ~mark @@ -120,6 +120,6 @@ module OptionMonad = struct Expr.some_constr, Expr.fun_id ~var_name mark (* | Some x -> x*); ] in - if toplevel then Expr.ematch arg Expr.option_enum cases mark - else return ~mark (Expr.ematch arg Expr.option_enum cases mark) + if toplevel then Expr.ematch ~e:arg ~name:Expr.option_enum ~cases mark + else return ~mark (Expr.ematch ~e:arg ~name:Expr.option_enum ~cases mark) end diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 4fd9d95f..510ff89d 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -70,7 +70,7 @@ let rec transform_closures_expr : cases (free_vars, EnumConstructor.Map.empty) in - free_vars, Expr.ematch new_e name new_cases m + free_vars, Expr.ematch ~e:new_e ~name ~cases:new_cases m | EApp { f = EAbs { binder; tys }, e1_pos; args } -> (* let-binding, we should not close these *) let vars, body = Bindlib.unmbind binder in @@ -333,11 +333,11 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = | TEnum e -> EnumConstructor.Map.exists (fun _ t' -> type_contains_arrow t') - (EnumName.Map.find e p.decl_ctx.ctx_enums) + (snd (EnumName.Map.find e p.decl_ctx.ctx_enums)) | TStruct s -> StructField.Map.exists (fun _ t' -> type_contains_arrow t') - (StructName.Map.find s p.decl_ctx.ctx_structs) + (snd (StructName.Map.find s p.decl_ctx.ctx_structs)) in let replace_fun_typs t = if type_contains_arrow t then Mark.copy t TAny else t @@ -346,11 +346,11 @@ let transform_closures_program (p : 'm program) : 'm program Bindlib.box = p.decl_ctx with ctx_structs = StructName.Map.map - (StructField.Map.map replace_fun_typs) + (fun (p, def) -> p, StructField.Map.map replace_fun_typs def) p.decl_ctx.ctx_structs; ctx_enums = EnumName.Map.map - (EnumConstructor.Map.map replace_fun_typs) + (fun (p, def) -> p, EnumConstructor.Map.map replace_fun_typs def) p.decl_ctx.ctx_enums; } in @@ -394,7 +394,7 @@ let rec hoist_closures_expr : cases (collected_closures, EnumConstructor.Map.empty) in - collected_closures, Expr.ematch new_e name new_cases m + collected_closures, Expr.ematch ~e:new_e ~name ~cases:new_cases m | EApp { f = EAbs { binder; tys }, e1_pos; args } -> (* let-binding, we should not close these *) let vars, body = Bindlib.unmbind binder in diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 6eeb34e5..40f88e06 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -113,7 +113,7 @@ let rec trans (ctx : typed ctx) (e : typed D.expr) : (lcalc, typed) boxed_gexpr if (Var.Map.find x ctx.ctx_vars).info_pure then Ast.OptionMonad.return (Expr.evar (trans_var ctx x) m) ~mark else Expr.evar (trans_var ctx x) m - | EExternal eref -> Expr.eexternal eref mark + | EExternal _ as e -> Expr.map ~f:(trans ctx) (e, m) | EApp { f = EVar v, _; args = [(ELit LUnit, _)] } -> (* Invariant: as users cannot write thunks, it can only come from prior compilation passes. Hence we can safely remove those. *) @@ -169,7 +169,7 @@ let rec trans (ctx : typed ctx) (e : typed D.expr) : (lcalc, typed) boxed_gexpr Ast.OptionMonad.return ~mark (Expr.eapp (Expr.evar (trans_var ctx scope) mark) - [Expr.estruct name (StructField.Map.map (trans ctx) fields) mark] + [Expr.estruct ~name ~fields:(StructField.Map.map (trans ctx) fields) mark] mark) | EApp { f = (EVar ff, _) as f; args } when not (Var.Map.find ff ctx.ctx_vars).is_scope -> @@ -395,7 +395,7 @@ let rec trans (ctx : typed ctx) (e : typed D.expr) : (lcalc, typed) boxed_gexpr in Ast.OptionMonad.bind_cont ~var_name:(context_or_same_var ctx e) - (fun e -> Expr.ematch (Expr.evar e m) name cases m) + (fun e -> Expr.ematch ~e:(Expr.evar e m) ~name ~cases m) (trans ctx e) ~mark | EArray args -> Ast.OptionMonad.mbind_cont ~mark ~var_name:ctx.ctx_context_name @@ -418,7 +418,7 @@ let rec trans (ctx : typed ctx) (e : typed D.expr) : (lcalc, typed) boxed_gexpr xs) ~f:StructField.Map.add ~init:StructField.Map.empty in - Ast.OptionMonad.return ~mark (Expr.estruct name fields mark)) + Ast.OptionMonad.return ~mark (Expr.estruct ~name ~fields mark)) (List.map (trans ctx) fields) ~mark | EIfThenElse { cond; etrue; efalse } -> @@ -433,12 +433,12 @@ let rec trans (ctx : typed ctx) (e : typed D.expr) : (lcalc, typed) boxed_gexpr ~var_name:(context_or_same_var ctx e) (fun e -> Ast.OptionMonad.return ~mark - (Expr.einj (Expr.evar e mark) cons name mark)) + (Expr.einj ~e:(Expr.evar e mark) ~cons ~name mark)) (trans ctx e) ~mark | EStructAccess { name; e; field } -> Ast.OptionMonad.bind_cont ~var_name:(context_or_same_var ctx e) - (fun e -> Expr.estructaccess (Expr.evar e mark) field name mark) + (fun e -> Expr.estructaccess ~e:(Expr.evar e mark) ~field ~name mark) (trans ctx e) ~mark | ETuple args -> Ast.OptionMonad.mbind_cont ~var_name:ctx.ctx_context_name @@ -653,8 +653,8 @@ and trans_scope_body_expr ctx s : Bindlib.box_apply (fun e -> Result e) (Expr.Box.lift - @@ Expr.estruct name - (StructField.Map.map (trans ctx) fields) + @@ Expr.estruct ~name + ~fields:(StructField.Map.map (trans ctx) fields) (Mark.get e)) | _ -> assert false end @@ -741,7 +741,7 @@ let translate_program (prgm : typed D.program) : untyped A.program = prgm.decl_ctx with ctx_enums = prgm.decl_ctx.ctx_enums - |> EnumName.Map.add Expr.option_enum Expr.option_enum_config; + |> EnumName.Map.add Expr.option_enum ([], Expr.option_enum_config); } in let decl_ctx = @@ -749,7 +749,8 @@ let translate_program (prgm : typed D.program) : untyped A.program = decl_ctx with ctx_structs = prgm.decl_ctx.ctx_structs - |> StructName.Map.mapi (fun _n str -> + |> StructName.Map.mapi (fun _n (path, str) -> + path, StructField.Map.map trans_typ_keep str); } in diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index c6290e2c..13366e40 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -19,22 +19,6 @@ open Shared_ast open Ast module D = Dcalc.Ast -let find_struct (s : StructName.t) (ctx : decl_ctx) : typ StructField.Map.t = - try StructName.Map.find s ctx.ctx_structs - with Not_found -> - let s_name, pos = StructName.get_info s in - Message.raise_spanned_error pos - "Internal Error: Structure %s was not found in the current environment." - s_name - -let find_enum (en : EnumName.t) (ctx : decl_ctx) : typ EnumConstructor.Map.t = - try EnumName.Map.find en ctx.ctx_enums - with Not_found -> - let en_name, pos = EnumName.get_info en in - Message.raise_spanned_error pos - "Internal Error: Enumeration %s was not found in the current environment." - en_name - let format_lit (fmt : Format.formatter) (l : lit Mark.pos) : unit = match Mark.remove l with | LBool b -> Print.lit fmt (LBool b) @@ -233,9 +217,9 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = | TAny -> Format.fprintf fmt "_" | TClosureEnv -> failwith "unimplemented!" -let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = +let format_var_str (fmt : Format.formatter) (v : string) : unit = let lowercase_name = - String.to_snake_case (String.to_ascii (Bindlib.name_of v)) + String.to_snake_case (String.to_ascii v) in let lowercase_name = Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") @@ -245,11 +229,15 @@ let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = let lowercase_name = String.to_ascii lowercase_name in if List.mem lowercase_name ["handle_default"; "handle_default_opt"] - || String.begins_with_uppercase (Bindlib.name_of v) + (* O_O *) + || String.begins_with_uppercase v then Format.pp_print_string fmt lowercase_name else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name else Format.fprintf fmt "%s_" lowercase_name +let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit = + format_var_str fmt (Bindlib.name_of v) + let needs_parens (e : 'm expr) : bool = match Mark.remove e with | EApp { f = EAbs _, _; _ } @@ -288,7 +276,13 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : in match Mark.remove e with | EVar v -> Format.fprintf fmt "%a" format_var v - | EExternal qid -> Qident.format fmt qid + | EExternal { path; name } -> + Print.path fmt path; + (* FIXME: this is wrong in general !! + We assume the idents exposed by the module depend only on the original name, while they actually get through Bindlib and may have been renamed. A correct implem could use the runtime registration used by the interpreter, but that would be distasteful and incur a penalty ; or we would need to reproduce the same structure as in the original module to ensure that bindlib performs the exact same renamings ; or finally we could normalise the names at generation time (either at toplevel or in a dedicated submodule ?) *) + (match Mark.remove name with + | External_value name -> format_var_str fmt (Mark.remove (TopdefName.get_info name)) + | External_scope name -> format_var_str fmt (Mark.remove (ScopeName.get_info name))) | ETuple es -> Format.fprintf fmt "@[(%a)@]" (Format.pp_print_list @@ -555,9 +549,13 @@ let format_ctx (fun struct_or_enum -> match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> - Format.fprintf fmt "%a@\n" format_struct_decl (s, find_struct s ctx) + let path, def = StructName.Map.find s ctx.ctx_structs in + if path = [] then + Format.fprintf fmt "%a@\n" format_struct_decl (s, def) | Scopelang.Dependency.TVertex.Enum e -> - Format.fprintf fmt "%a@\n" format_enum_decl (e, find_enum e ctx)) + let path, def = EnumName.Map.find e ctx.ctx_enums in + if path = [] then + Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) (type_ordering @ scope_structs) let rename_vars e = @@ -616,7 +614,7 @@ let format_scope_exec scope_body = let scope_name_str = Mark.remove (ScopeName.get_info scope_name) in let scope_var = String.Map.find scope_name_str bnd in - let scope_input = + let _, scope_input = StructName.Map.find scope_body.scope_body_input_struct ctx.ctx_structs in if not (StructField.Map.is_empty scope_input) then diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index f695a0f3..618813ed 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -19,8 +19,6 @@ open Shared_ast (** Formats a lambda calculus program into a valid OCaml program *) val avoid_keywords : string -> string -val find_struct : StructName.t -> decl_ctx -> typ StructField.Map.t -val find_enum : EnumName.t -> decl_ctx -> typ EnumConstructor.Map.t val typ_needs_parens : typ -> bool (* val needs_parens : 'm expr -> bool *) diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index 52f7c390..4cd232fe 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -139,14 +139,14 @@ module To_jsoo = struct | TArrow _ -> Format.fprintf fmt "Js.meth" | _ -> Format.fprintf fmt "Js.readonly_prop" in - let format_struct_decl fmt (struct_name, struct_fields) = + let format_struct_decl fmt (struct_name, (path, struct_fields)) = let fmt_struct_name fmt _ = format_struct_name fmt struct_name in let fmt_module_struct_name fmt _ = + Print.path fmt path; To_ocaml.format_to_module_name fmt (`Sname struct_name) in let fmt_to_jsoo fmt _ = - Format.fprintf fmt "%a" - (Format.pp_print_list + Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (struct_field, struct_field_type) -> match Mark.remove struct_field_type with @@ -172,12 +172,12 @@ module To_jsoo = struct Format.fprintf fmt "@[val %a =@ %a %a.%a@]" format_struct_field_name_camel_case struct_field format_typ_to_jsoo struct_field_type fmt_struct_name () - format_struct_field_name (None, struct_field))) + format_struct_field_name (None, struct_field)) + fmt (StructField.Map.bindings struct_fields) in let fmt_of_jsoo fmt _ = - Format.fprintf fmt "%a" - (Format.pp_print_list + Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") (fun fmt (struct_field, struct_field_type) -> match Mark.remove struct_field_type with @@ -192,7 +192,8 @@ module To_jsoo = struct "@[%a =@ @[%a@ @[%a@,##.%a@]@]@]" format_struct_field_name (None, struct_field) format_typ_of_jsoo struct_field_type fmt_struct_name () - format_struct_field_name_camel_case struct_field)) + format_struct_field_name_camel_case struct_field) + fmt (StructField.Map.bindings struct_fields) in let fmt_conv_funs fmt _ = @@ -230,10 +231,11 @@ module To_jsoo = struct (StructField.Map.bindings struct_fields) fmt_conv_funs () in - let format_enum_decl fmt (enum_name, (enum_cons : typ EnumConstructor.Map.t)) + let format_enum_decl fmt (enum_name, (path, (enum_cons : typ EnumConstructor.Map.t))) = let fmt_enum_name fmt _ = format_enum_name fmt enum_name in - let fmt_module_enum_name fmt _ = + let fmt_module_enum_name fmt () = + Print.path fmt path; To_ocaml.format_to_module_name fmt (`Ename enum_name) in let fmt_to_jsoo fmt _ = @@ -332,9 +334,11 @@ module To_jsoo = struct (fun struct_or_enum -> match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> - Format.fprintf fmt "%a@\n" format_struct_decl (s, find_struct s ctx) + Format.fprintf fmt "%a@\n" format_struct_decl + (s, StructName.Map.find s ctx.ctx_structs) | Scopelang.Dependency.TVertex.Enum e -> - Format.fprintf fmt "%a@\n" format_enum_decl (e, find_enum e ctx)) + Format.fprintf fmt "%a@\n" format_enum_decl + (e, EnumName.Map.find e ctx.ctx_enums)) (type_ordering @ scope_structs) let fmt_input_struct_name fmt (scope_body : 'a expr scope_body) = diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index 3cb5824d..8bb108b9 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -76,13 +76,15 @@ module To_json = struct (ctx : decl_ctx) (fmt : Format.formatter) (sname : StructName.t) = + let path, fields = StructName.Map.find sname ctx.ctx_structs in Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") (fun fmt (field_name, field_type) -> - Format.fprintf fmt "@[\"%a\": {@\n%a@]@\n}" + Format.fprintf fmt "@[\"%a%a\": {@\n%a@]@\n}" + Print.path path format_struct_field_name_camel_case field_name fmt_type field_type) fmt - (StructField.Map.bindings (find_struct sname ctx)) + (StructField.Map.bindings fields) let fmt_definitions (ctx : decl_ctx) @@ -103,17 +105,18 @@ module To_json = struct (t :: acc) @ collect_required_type_defs_from_scope_input s | TEnum e -> List.fold_left collect (t :: acc) - (EnumConstructor.Map.values (EnumName.Map.find e ctx.ctx_enums)) + (EnumConstructor.Map.values (snd (EnumName.Map.find e ctx.ctx_enums))) | TArray t -> collect acc t | _ -> acc in - find_struct input_struct ctx + StructName.Map.find input_struct ctx.ctx_structs + |> snd |> StructField.Map.values |> List.fold_left (fun acc field_typ -> collect acc field_typ) [] |> List.sort_uniq (fun t t' -> String.compare (get_name t) (get_name t')) in let fmt_enum_properties fmt ename = - let enum_def = find_enum ename ctx in + let _path, enum_def = EnumName.Map.find ename ctx.ctx_enums in Format.fprintf fmt "@[\"kind\": {@\n\ \"type\": \"string\",@\n\ diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index b9916515..55f8b82b 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -233,8 +233,8 @@ let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : log "====================="; let m = Mark.get e in let application_arg = - Expr.estruct scope_arg_struct - (StructField.Map.map + Expr.estruct ~name:scope_arg_struct + ~fields:(StructField.Map.map (function | TArrow (ty_in, ty_out), _ -> Expr.make_abs @@ -242,7 +242,7 @@ let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : (Bindlib.box EEmptyError, Expr.with_ty m ty_out) ty_in (Expr.mark_pos m) | ty -> Expr.evar (Var.make "undefined_input") (Expr.with_ty m ty)) - (StructName.Map.find scope_arg_struct ctx.ctx_structs)) + (snd (StructName.Map.find scope_arg_struct ctx.ctx_structs))) m in let e_app = Expr.eapp (Expr.box e) [application_arg] m in diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 8e01b45a..dcbebc2c 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -42,7 +42,8 @@ let rec format_expr | EVar v -> Format.fprintf fmt "%a" format_var_name v | EFunc v -> Format.fprintf fmt "%a" format_func_name v | EStruct (es, s) -> - Format.fprintf fmt "@[%a@ %a%a%a@]" StructName.format s + let path, fields = StructName.Map.find s decl_ctx.ctx_structs in + Format.fprintf fmt "@[%a%a@ %a%a%a@]" Print.path path StructName.format s Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") @@ -50,8 +51,7 @@ let rec format_expr Format.fprintf fmt "%a%a%a%a %a" Print.punctuation "\"" StructField.format struct_field Print.punctuation "\"" Print.punctuation ":" format_expr e)) - (List.combine es - (StructField.Map.bindings (StructName.Map.find s decl_ctx.ctx_structs))) + (List.combine es (StructField.Map.bindings fields)) Print.punctuation "}" | EArray es -> Format.fprintf fmt "@[%a%a%a@]" Print.punctuation "[" @@ -142,20 +142,20 @@ let rec format_statement (format_expr decl_ctx ~debug) (naked_expr, Mark.get stmt) | SSwitch (e_switch, enum, arms) -> + let path, cons = EnumName.Map.find enum decl_ctx.ctx_enums in Format.fprintf fmt "@[%a @[%a@]%a@]%a" Print.keyword "switch" (format_expr decl_ctx ~debug) e_switch Print.punctuation ":" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt ((case, _), (arm_block, payload_name)) -> - Format.fprintf fmt "%a %a%a@ %a @[%a@ %a@]" Print.punctuation - "|" Print.enum_constructor case Print.punctuation ":" + Format.fprintf fmt "%a %a%a%a@ %a @[%a@ %a@]" Print.punctuation + "|" Print.path path Print.enum_constructor case Print.punctuation ":" format_var_name payload_name Print.punctuation "→" (format_block decl_ctx ~debug) arm_block)) (List.combine - (EnumConstructor.Map.bindings - (EnumName.Map.find enum decl_ctx.ctx_enums)) + (EnumConstructor.Map.bindings cons) arms) and format_block diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index bc3a1a67..59765149 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -274,14 +274,16 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : | EVar v -> format_var fmt v | EFunc f -> format_func_name fmt f | EStruct (es, s) -> - Format.fprintf fmt "%a(%a)" format_struct_name s + let path, fields = + StructName.Map.find s ctx.ctx_structs + in + Format.fprintf fmt "%a%a(%a)" Print.path path format_struct_name s (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, (struct_field, _)) -> Format.fprintf fmt "%a = %a" format_struct_field_name struct_field (format_expression ctx) e)) - (List.combine es - (StructField.Map.bindings (StructName.Map.find s ctx.ctx_structs))) + (List.combine es (StructField.Map.bindings fields)) | EStructFieldAccess (e1, field, _) -> Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field @@ -426,11 +428,12 @@ let rec format_statement (format_block ctx) case_none format_var case_some_var format_var tmp_var (format_block ctx) case_some | SSwitch (e1, e_name, cases) -> + let path, cons_map = EnumName.Map.find e_name ctx.ctx_enums in let cases = List.map2 (fun (x, y) (cons, _) -> x, y, cons) cases - (EnumConstructor.Map.bindings (EnumName.Map.find e_name ctx.ctx_enums)) + (EnumConstructor.Map.bindings cons_map) in let tmp_var = VarName.fresh ("match_arg", Pos.no_pos) in Format.fprintf fmt "%a = %a@\n@[if %a@]" format_var tmp_var @@ -438,8 +441,8 @@ let rec format_statement (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt (case_block, payload_var, cons_name) -> - Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" - format_var tmp_var format_enum_name e_name format_enum_cons_name + Format.fprintf fmt "%a.code == %a%a_Code.%a:@\n%a = %a.value@\n%a" + format_var tmp_var Print.path path format_enum_name e_name format_enum_cons_name cons_name format_var payload_var format_var tmp_var (format_block ctx) case_block)) cases @@ -584,10 +587,10 @@ let format_ctx match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> Format.fprintf fmt "%a@\n@\n" format_struct_decl - (s, StructName.Map.find s ctx.ctx_structs) + (s, snd (StructName.Map.find s ctx.ctx_structs)) | Scopelang.Dependency.TVertex.Enum e -> Format.fprintf fmt "%a@\n@\n" format_enum_decl - (e, EnumName.Map.find e ctx.ctx_enums)) + (e, snd (EnumName.Map.find e ctx.ctx_enums))) (type_ordering @ scope_structs) let format_program diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index 94382ab7..163dbee4 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -41,19 +41,19 @@ let rec locations_used (e : 'm expr) : LocationSet.t = type 'm rule = | Definition of location Mark.pos * typ * Desugared.Ast.io * 'm expr | Assertion of 'm expr - | Call of ScopeName.t * SubScopeName.t * 'm mark + | Call of (path * ScopeName.t) * SubScopeName.t * 'm mark type 'm scope_decl = { scope_decl_name : ScopeName.t; scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t; scope_decl_rules : 'm rule list; - scope_mark : 'm mark; scope_options : Desugared.Ast.catala_option Mark.pos list; } type 'm program = { - program_scopes : 'm scope_decl ScopeName.Map.t; + program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t; + program_modules : nil program ModuleName.Map.t; program_ctx : decl_ctx; } @@ -70,11 +70,34 @@ let type_rule decl_ctx env = function Call (sc_name, ssc_name, Typed { pos; ty = Mark.add pos TAny }) let type_program (prg : 'm program) : typed program = + let base_typing_env prg = + let typing_env = Typing.Env.empty prg.program_ctx in + let typing_env = + TopdefName.Map.fold + (fun name (_, ty) -> Typing.Env.add_toplevel_var name ty) + prg.program_topdefs + typing_env + in + let typing_env = + ScopeName.Map.fold + (fun scope_name scope_decl -> + let vars = ScopeVar.Map.map fst (Mark.remove scope_decl).scope_sig in + Typing.Env.add_scope scope_name ~vars) + prg.program_scopes typing_env + in + typing_env + in + let rec build_typing_env prg = + ModuleName.Map.fold (fun modname prg -> + Typing.Env.add_module modname ~module_env:(build_typing_env prg)) + prg.program_modules + (base_typing_env prg) + in let typing_env = - TopdefName.Map.fold - (fun name (_, ty) -> Typing.Env.add_toplevel_var name ty) - prg.program_topdefs - (Typing.Env.empty prg.program_ctx) + ModuleName.Map.fold (fun modname prg -> + Typing.Env.add_module modname ~module_env:(build_typing_env prg)) + prg.program_modules + (base_typing_env prg) in let program_topdefs = TopdefName.Map.map @@ -85,16 +108,9 @@ let type_program (prg : 'm program) : typed program = typ )) prg.program_topdefs in - let typing_env = - ScopeName.Map.fold - (fun scope_name scope_decl -> - let vars = ScopeVar.Map.map fst scope_decl.scope_sig in - Typing.Env.add_scope scope_name ~vars) - prg.program_scopes typing_env - in let program_scopes = ScopeName.Map.map - (fun scope_decl -> + (Mark.map (fun scope_decl -> let typing_env = ScopeVar.Map.fold (fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env) @@ -105,11 +121,7 @@ let type_program (prg : 'm program) : typed program = (type_rule prg.program_ctx typing_env) scope_decl.scope_decl_rules in - let scope_mark = - let pos = Mark.get (ScopeName.get_info scope_decl.scope_decl_name) in - Typed { pos; ty = Mark.add pos TAny } - in - { scope_decl with scope_decl_rules; scope_mark }) + {scope_decl with scope_decl_rules})) prg.program_scopes in { prg with program_topdefs; program_scopes } diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index 069c8b06..0c23a772 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -34,19 +34,20 @@ val locations_used : 'm expr -> LocationSet.t type 'm rule = | Definition of location Mark.pos * typ * Desugared.Ast.io * 'm expr | Assertion of 'm expr - | Call of ScopeName.t * SubScopeName.t * 'm mark + | Call of (path * ScopeName.t) * SubScopeName.t * 'm mark type 'm scope_decl = { scope_decl_name : ScopeName.t; scope_sig : (typ * Desugared.Ast.io) ScopeVar.Map.t; scope_decl_rules : 'm rule list; - scope_mark : 'm mark; scope_options : Desugared.Ast.catala_option Mark.pos list; } type 'm program = { - program_scopes : 'm scope_decl ScopeName.Map.t; + program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t; + program_modules : nil program ModuleName.Map.t; + (* Using [nil] here ensure that program interfaces don't contain any expressions. They won't contain any rules or topdefs, but will still have the scope signatures needed to respect the call convention *) program_ctx : decl_ctx; } diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index dff5efe1..f68921fa 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -82,8 +82,8 @@ let rec expr_used_defs e = e VMap.empty in match e with - | ELocation (ToplevelVar (v, pos)), _ -> VMap.singleton (Topdef v) pos - | (EScopeCall { scope; _ }, m) as e -> + | ELocation (ToplevelVar { path = []; name = v, pos }), _ -> VMap.singleton (Topdef v) pos + | (EScopeCall { path = []; scope; _ }, m) as e -> VMap.add (Scope scope) (Expr.mark_pos m) (recurse_subterms e) | EAbs { binder; _ }, _ -> let _, body = Bindlib.unmbind binder in @@ -95,7 +95,8 @@ let rule_used_defs = function (* TODO: maybe this info could be passed on from previous passes without walking through all exprs again *) expr_used_defs e - | Ast.Call (subscope, subindex, _) -> + | Ast.Call ((_::_path, _), _, _) -> VMap.empty + | Ast.Call (([], subscope), subindex, _) -> VMap.singleton (Scope subscope) (Mark.get (SubScopeName.get_info subindex)) let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = @@ -128,7 +129,7 @@ let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = prgm.program_topdefs g in ScopeName.Map.fold - (fun scope_name scope g -> + (fun scope_name (scope, _) g -> List.fold_left (fun g rule -> let used_defs = rule_used_defs rule in @@ -147,6 +148,7 @@ let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = used_defs g) g scope.Ast.scope_decl_rules) prgm.program_scopes g +(* TODO FIXME: Add submodules here, they may still need dependency resolution type-wise (?) *) let check_for_cycle_in_defs (g : SDependencies.t) : unit = (* if there is a cycle, there will be an strongly connected component of @@ -270,7 +272,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t let g = TDependencies.empty in let g = StructName.Map.fold - (fun s fields g -> + (fun s (path, fields) g -> StructField.Map.fold (fun _ typ g -> let def = TVertex.Struct s in @@ -280,8 +282,9 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t (fun used g -> if TVertex.equal used def then Message.raise_spanned_error (Mark.get typ) - "The type %a is defined using itself, which is forbidden \ + "The type %a%a is defined using itself, which is forbidden \ since Catala does not provide recursive types" + Print.path path TVertex.format used else let edge = TDependencies.E.create used (Mark.get typ) def in @@ -292,7 +295,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t in let g = EnumName.Map.fold - (fun e cases g -> + (fun e (path, cases) g -> EnumConstructor.Map.fold (fun _ typ g -> let def = TVertex.Enum e in @@ -302,8 +305,9 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t (fun used g -> if TVertex.equal used def then Message.raise_spanned_error (Mark.get typ) - "The type %a is defined using itself, which is forbidden \ + "The type %a%a is defined using itself, which is forbidden \ since Catala does not provide recursive types" + Print.path path TVertex.format used else let edge = TDependencies.E.create used (Mark.get typ) def in diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index c8a44253..a8862598 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -29,9 +29,18 @@ type target_scope_vars = type ctx = { decl_ctx : decl_ctx; scope_var_mapping : target_scope_vars ScopeVar.Map.t; - var_mapping : (Desugared.Ast.expr, untyped Ast.expr Var.t) Var.Map.t; + var_mapping : (D.expr, untyped Ast.expr Var.t) Var.Map.t; + modules : ctx ModuleName.Map.t; } +let rec module_ctx ctx = function + | [] -> ctx + | (modname, mpos) :: path -> + match ModuleName.Map.find_opt modname ctx.modules with + | None -> + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname + | Some ctx -> module_ctx ctx path + let tag_with_log_entry (e : untyped Ast.expr boxed) (l : log_entry) @@ -42,7 +51,7 @@ let tag_with_log_entry [e] (Mark.get e) else e -let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : +let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = let m = Mark.get e in match Mark.remove e with @@ -57,28 +66,33 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : ctx (Array.to_list vars) (Array.to_list new_vars) in Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m - | ELocation (SubScopeVar (s_name, ss_name, s_var)) -> + | ELocation (SubScopeVar { path; scope; alias; var }) -> (* When referring to a subscope variable in an expression, we are referring to the output, hence we take the last state. *) - let new_s_var = - match ScopeVar.Map.find (Mark.remove s_var) ctx.scope_var_mapping with - | WholeVar new_s_var -> Mark.copy s_var new_s_var - | States states -> Mark.copy s_var (snd (List.hd (List.rev states))) + let ctx = module_ctx ctx path in + let var = + match ScopeVar.Map.find (Mark.remove var) ctx.scope_var_mapping with + | WholeVar new_s_var -> Mark.copy var new_s_var + | States states -> Mark.copy var (snd (List.hd (List.rev states))) in - Expr.elocation (SubScopeVar (s_name, ss_name, new_s_var)) m - | ELocation (DesugaredScopeVar (s_var, None)) -> + Expr.elocation (SubScopeVar { path; scope; alias; var}) m + | ELocation (DesugaredScopeVar { name; state = None }) -> Expr.elocation (ScopelangScopeVar - (match ScopeVar.Map.find (Mark.remove s_var) ctx.scope_var_mapping with - | WholeVar new_s_var -> Mark.copy s_var new_s_var - | States _ -> failwith "should not happen")) + { name = + match ScopeVar.Map.find (Mark.remove name) ctx.scope_var_mapping + with + | WholeVar new_s_var -> Mark.copy name new_s_var + | States _ -> failwith "should not happen" } ) m - | ELocation (DesugaredScopeVar (s_var, Some state)) -> - Expr.elocation + | ELocation (DesugaredScopeVar { name; state = Some state }) -> + Expr.elocation (ScopelangScopeVar - (match ScopeVar.Map.find (Mark.remove s_var) ctx.scope_var_mapping with - | WholeVar _ -> failwith "should not happen" - | States states -> Mark.copy s_var (List.assoc state states))) + { name = + match ScopeVar.Map.find (Mark.remove name) ctx.scope_var_mapping + with + | WholeVar _ -> failwith "should not happen" + | States states -> Mark.copy name (List.assoc state states) }) m | ELocation (ToplevelVar v) -> Expr.elocation (ToplevelVar v) m | EDStructAccess { name_opt = None; _ } -> @@ -87,7 +101,7 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : one possible matching structure *) Message.raise_spanned_error (Expr.mark_pos m) "Ambiguous structure field access" - | EDStructAccess { e; field; name_opt = Some name } -> + | EDStructAccess { e; field; path = _; name_opt = Some name } -> let e' = translate_expr ctx e in let field = try @@ -100,10 +114,10 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : @{\"%a\"@}" field StructName.format name in - Expr.estructaccess e' field name m - | EScopeCall { scope; args } -> - Expr.escopecall scope - (ScopeVar.Map.fold + Expr.estructaccess ~e:e' ~field ~name m + | EScopeCall { path; scope; args } -> + Expr.escopecall ~path ~scope + ~args:(ScopeVar.Map.fold (fun v e args' -> let v' = match ScopeVar.Map.find v ctx.scope_var_mapping with @@ -132,7 +146,7 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : | EOp _ -> assert false (* Only allowed within [EApp] *) | ( EStruct _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | ELit _ | EApp _ | EDefault _ | EIfThenElse _ | EArray _ | EEmptyError - | EErrorOnEmpty _ | EExternal _ ) as e -> + | EErrorOnEmpty _) as e -> Expr.map ~f:(translate_expr ctx) (e, m) (** {1 Rule tree construction} *) @@ -140,31 +154,31 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : (** Intermediate representation for the exception tree of rules for a particular scope definition. *) type rule_tree = - | Leaf of Desugared.Ast.rule list + | Leaf of D.rule list (** Rules defining a base case piecewise. List is non-empty. *) - | Node of rule_tree list * Desugared.Ast.rule list + | Node of rule_tree list * D.rule list (** [Node (exceptions, base_case)] is a list of exceptions to a non-empty list of rules defining a base case piecewise. *) (** Transforms a flat list of rules into a tree, taking into account the priorities declared between rules *) let def_to_exception_graph - (def_info : Desugared.Ast.ScopeDef.t) - (def : Desugared.Ast.rule RuleName.Map.t) : + (def_info : D.ScopeDef.t) + (def : D.rule RuleName.Map.t) : Desugared.Dependency.ExceptionsDependencies.t = let exc_graph = Desugared.Dependency.build_exceptions_graph def def_info in Desugared.Dependency.check_for_exception_cycle def exc_graph; exc_graph -let rule_to_exception_graph (scope : Desugared.Ast.scope) = function +let rule_to_exception_graph (scope : D.scope) = function | Desugared.Dependency.Vertex.Var (var, state) -> ( let scope_def = - Desugared.Ast.ScopeDef.Map.find - (Desugared.Ast.ScopeDef.Var (var, state)) + D.ScopeDef.Map.find + (D.ScopeDef.Var (var, state)) scope.scope_defs in let var_def = scope_def.D.scope_def_rules in - match Mark.remove scope_def.Desugared.Ast.scope_def_io.io_input with + match Mark.remove scope_def.D.scope_def_io.io_input with | OnlyInput when not (RuleName.Map.is_empty var_def) -> (* If the variable is tagged as input, then it shall not be redefined. *) Message.raise_multispanned_error @@ -176,29 +190,29 @@ let rule_to_exception_graph (scope : Desugared.Ast.scope) = function (RuleName.Map.keys var_def)) "It is impossible to give a definition to a scope variable tagged as \ input." - | OnlyInput -> Desugared.Ast.ScopeDef.Map.empty + | OnlyInput -> D.ScopeDef.Map.empty (* we do not provide any definition for an input-only variable *) | _ -> - Desugared.Ast.ScopeDef.Map.singleton - (Desugared.Ast.ScopeDef.Var (var, state)) + D.ScopeDef.Map.singleton + (D.ScopeDef.Var (var, state)) (def_to_exception_graph - (Desugared.Ast.ScopeDef.Var (var, state)) + (D.ScopeDef.Var (var, state)) var_def)) | Desugared.Dependency.Vertex.SubScope sub_scope_index -> (* Before calling the sub_scope, we need to include all the re-definitions of subscope parameters*) let sub_scope_vars_redefs_candidates = - Desugared.Ast.ScopeDef.Map.filter + D.ScopeDef.Map.filter (fun def_key scope_def -> match def_key with - | Desugared.Ast.ScopeDef.Var _ -> false - | Desugared.Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) -> + | D.ScopeDef.Var _ -> false + | D.ScopeDef.SubScopeVar (sub_scope_index', _, _) -> sub_scope_index = sub_scope_index' (* We exclude subscope variables that have 0 re-definitions and are not visible in the input of the subscope *) && not ((match - Mark.remove scope_def.Desugared.Ast.scope_def_io.io_input + Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> true | _ -> false) @@ -206,18 +220,18 @@ let rule_to_exception_graph (scope : Desugared.Ast.scope) = function scope.scope_defs in let sub_scope_vars_redefs = - Desugared.Ast.ScopeDef.Map.mapi + D.ScopeDef.Map.mapi (fun def_key scope_def -> - let def = scope_def.Desugared.Ast.scope_def_rules in + let def = scope_def.D.scope_def_rules in let is_cond = scope_def.scope_def_is_condition in match def_key with - | Desugared.Ast.ScopeDef.Var _ -> assert false (* should not happen *) - | Desugared.Ast.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) -> + | D.ScopeDef.Var _ -> assert false (* should not happen *) + | D.ScopeDef.SubScopeVar (sscope, sub_scope_var, pos) -> (* This definition redefines a variable of the correct subscope. But we have to check that this redefinition is allowed with respect to the io parameters of that subscope variable. *) (match - Mark.remove scope_def.Desugared.Ast.scope_def_io.io_input + Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> Message.raise_multispanned_error @@ -245,23 +259,23 @@ let rule_to_exception_graph (scope : Desugared.Ast.scope) = function was provided." | _ -> ()); let exc_graph = def_to_exception_graph def_key def in - let var_pos = Desugared.Ast.ScopeDef.get_position def_key in + let var_pos = D.ScopeDef.get_position def_key in exc_graph, sub_scope_var, var_pos) sub_scope_vars_redefs_candidates in List.fold_left (fun exc_graphs (new_exc_graph, subscope_var, var_pos) -> - Desugared.Ast.ScopeDef.Map.add - (Desugared.Ast.ScopeDef.SubScopeVar + D.ScopeDef.Map.add + (D.ScopeDef.SubScopeVar (sub_scope_index, subscope_var, var_pos)) new_exc_graph exc_graphs) - Desugared.Ast.ScopeDef.Map.empty - (Desugared.Ast.ScopeDef.Map.values sub_scope_vars_redefs) + D.ScopeDef.Map.empty + (D.ScopeDef.Map.values sub_scope_vars_redefs) | Assertion _ -> - Desugared.Ast.ScopeDef.Map.empty (* no exceptions for assertions *) + D.ScopeDef.Map.empty (* no exceptions for assertions *) -let scope_to_exception_graphs (scope : Desugared.Ast.scope) : - Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t = +let scope_to_exception_graphs (scope : D.scope) : + Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t = let scope_dependencies = Desugared.Dependency.build_scope_dependencies scope in @@ -272,25 +286,25 @@ let scope_to_exception_graphs (scope : Desugared.Ast.scope) : List.fold_left (fun exceptions_graphs scope_def_key -> let new_exceptions_graphs = rule_to_exception_graph scope scope_def_key in - Desugared.Ast.ScopeDef.Map.union + D.ScopeDef.Map.union (fun _ _ _ -> assert false (* there should not be key conflicts *)) new_exceptions_graphs exceptions_graphs) - Desugared.Ast.ScopeDef.Map.empty scope_ordering + D.ScopeDef.Map.empty scope_ordering -let build_exceptions_graph (pgrm : Desugared.Ast.program) : - Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t = +let build_exceptions_graph (pgrm : D.program) : + Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t = ScopeName.Map.fold (fun _ scope exceptions_graph -> let new_exceptions_graphs = scope_to_exception_graphs scope in - Desugared.Ast.ScopeDef.Map.union + D.ScopeDef.Map.union (fun _ _ _ -> assert false (* key conflicts should not happen*)) new_exceptions_graphs exceptions_graph) - pgrm.program_scopes Desugared.Ast.ScopeDef.Map.empty + pgrm.program_scopes D.ScopeDef.Map.empty (** Transforms a flat list of rules into a tree, taking into account the priorities declared between rules *) let def_map_to_tree - (def : Desugared.Ast.rule RuleName.Map.t) + (def : D.rule RuleName.Map.t) (exc_graph : Desugared.Dependency.ExceptionsDependencies.t) : rule_tree list = (* we start by the base cases: they are the vertices which have no @@ -328,7 +342,7 @@ let rec rule_tree_to_expr ~(is_reentrant_var : bool) (ctx : ctx) (def_pos : Pos.t) - (params : Desugared.Ast.expr Var.t list option) + (params : D.expr Var.t list option) (tree : rule_tree) : untyped Ast.expr boxed = let emark = Untyped { pos = def_pos } in let exceptions, base_rules = @@ -338,9 +352,9 @@ let rec rule_tree_to_expr the whole rule tree into a function, we need to perform some alpha-renaming of all the expressions *) let substitute_parameter - (e : Desugared.Ast.expr boxed) - (rule : Desugared.Ast.rule) : Desugared.Ast.expr boxed = - match params, rule.Desugared.Ast.rule_parameter with + (e : D.expr boxed) + (rule : D.rule) : D.expr boxed = + match params, rule.D.rule_parameter with | Some new_params, Some (old_params_with_types, _) -> let old_params, _ = List.split old_params_with_types in let old_params = Array.of_list (List.map Mark.remove old_params) in @@ -377,15 +391,15 @@ let rec rule_tree_to_expr in let base_just_list = List.map - (fun rule -> substitute_parameter rule.Desugared.Ast.rule_just rule) + (fun rule -> substitute_parameter rule.D.rule_just rule) base_rules in let base_cons_list = List.map - (fun rule -> substitute_parameter rule.Desugared.Ast.rule_cons rule) + (fun rule -> substitute_parameter rule.D.rule_cons rule) base_rules in - let translate_and_unbox_list (list : Desugared.Ast.expr boxed list) : + let translate_and_unbox_list (list : D.expr boxed list) : untyped Ast.expr boxed list = List.map (fun e -> @@ -419,7 +433,7 @@ let rec rule_tree_to_expr (Expr.elit (LBool true) emark) default_containing_base_cases emark in - match params, (List.hd base_rules).Desugared.Ast.rule_parameter with + match params, (List.hd base_rules).D.rule_parameter with | None, None -> default | Some new_params, Some (ls, _) -> let _, tys = List.split ls in @@ -449,33 +463,33 @@ let translate_def ~(is_cond : bool) ~(is_subscope_var : bool) (ctx : ctx) - (def_info : Desugared.Ast.ScopeDef.t) - (def : Desugared.Ast.rule RuleName.Map.t) + (def_info : D.ScopeDef.t) + (def : D.rule RuleName.Map.t) (params : (Uid.MarkedString.info * typ) list Mark.pos option) (typ : typ) - (io : Desugared.Ast.io) + (io : D.io) (exc_graph : Desugared.Dependency.ExceptionsDependencies.t) : untyped Ast.expr boxed = (* Here, we have to transform this list of rules into a default tree. *) let top_list = def_map_to_tree def exc_graph in let is_input = - match Mark.remove io.Desugared.Ast.io_input with + match Mark.remove io.D.io_input with | OnlyInput -> true | _ -> false in let is_reentrant = - match Mark.remove io.Desugared.Ast.io_input with + match Mark.remove io.D.io_input with | Reentrant -> true | _ -> false in - let top_value : Desugared.Ast.rule option = + let top_value : D.rule option = if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then (* We add the bottom [false] value for conditions, only for the scope where the condition is declared. Except when the variable is an input, where we want the [false] to be added at each caller parent scope. *) Some - (Desugared.Ast.always_false_rule - (Desugared.Ast.ScopeDef.get_position def_info) + (D.always_false_rule + (D.ScopeDef.get_position def_info) params) else None in @@ -505,7 +519,7 @@ let translate_def will not be provided by the calee scope, it has to be placed in the caller. *) then - let m = Untyped { pos = Desugared.Ast.ScopeDef.get_position def_info } in + let m = Untyped { pos = D.ScopeDef.get_position def_info } in let empty_error = Expr.eemptyerror m in match params with | Some (ps, _) -> @@ -517,7 +531,7 @@ let translate_def | _ -> empty_error else rule_tree_to_expr ~toplevel:true ~is_reentrant_var:is_reentrant ctx - (Desugared.Ast.ScopeDef.get_position def_info) + (D.ScopeDef.get_position def_info) (Option.map (fun (ps, _) -> (List.map (fun (lbl, _) -> Var.make (Mark.remove lbl))) ps) @@ -526,7 +540,7 @@ let translate_def | [], None -> (* In this case, there are no rules to define the expression and no default value so we put an empty rule. *) - Leaf [Desugared.Ast.empty_rule (Mark.get typ) params] + Leaf [D.empty_rule (Mark.get typ) params] | [], Some top_value -> (* In this case, there are no rules to define the expression but a default value so we put it. *) @@ -537,35 +551,35 @@ let translate_def Node (top_list, [top_value]) | [top_tree], None -> top_tree | _, None -> - Node (top_list, [Desugared.Ast.empty_rule (Mark.get typ) params])) + Node (top_list, [D.empty_rule (Mark.get typ) params])) let translate_rule ctx - (scope : Desugared.Ast.scope) + (scope : D.scope) (exc_graphs : - Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t) + Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) = function | Desugared.Dependency.Vertex.Var (var, state) -> ( let scope_def = - Desugared.Ast.ScopeDef.Map.find - (Desugared.Ast.ScopeDef.Var (var, state)) + D.ScopeDef.Map.find + (D.ScopeDef.Var (var, state)) scope.scope_defs in let var_def = scope_def.D.scope_def_rules in let var_params = scope_def.D.scope_def_parameters in let var_typ = scope_def.D.scope_def_typ in let is_cond = scope_def.D.scope_def_is_condition in - match Mark.remove scope_def.Desugared.Ast.scope_def_io.io_input with + match Mark.remove scope_def.D.scope_def_io.io_input with | OnlyInput when not (RuleName.Map.is_empty var_def) -> assert false (* error already raised *) | OnlyInput -> [] (* we do not provide any definition for an input-only variable *) | _ -> - let scope_def_key = Desugared.Ast.ScopeDef.Var (var, state) in + let scope_def_key = D.ScopeDef.Var (var, state) in let expr_def = translate_def ctx scope_def_key var_def var_params var_typ - scope_def.Desugared.Ast.scope_def_io - (Desugared.Ast.ScopeDef.Map.find scope_def_key exc_graphs) + scope_def.D.scope_def_io + (D.ScopeDef.Map.find scope_def_key exc_graphs) ~is_cond ~is_subscope_var:false in let scope_var = @@ -577,10 +591,10 @@ let translate_rule [ Ast.Definition ( ( ScopelangScopeVar - (scope_var, Mark.get (ScopeVar.get_info scope_var)), + { name = scope_var, Mark.get (ScopeVar.get_info scope_var) }, Mark.get (ScopeVar.get_info scope_var) ), var_typ, - scope_def.Desugared.Ast.scope_def_io, + scope_def.D.scope_def_io, Expr.unbox expr_def ); ]) | Desugared.Dependency.Vertex.SubScope sub_scope_index -> @@ -590,17 +604,17 @@ let translate_rule SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in let sub_scope_vars_redefs_candidates = - Desugared.Ast.ScopeDef.Map.filter + D.ScopeDef.Map.filter (fun def_key scope_def -> match def_key with - | Desugared.Ast.ScopeDef.Var _ -> false - | Desugared.Ast.ScopeDef.SubScopeVar (sub_scope_index', _, _) -> + | D.ScopeDef.Var _ -> false + | D.ScopeDef.SubScopeVar (sub_scope_index', _, _) -> sub_scope_index = sub_scope_index' (* We exclude subscope variables that have 0 re-definitions and are not visible in the input of the subscope *) && not ((match - Mark.remove scope_def.Desugared.Ast.scope_def_io.io_input + Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> true | _ -> false) @@ -608,19 +622,19 @@ let translate_rule scope.scope_defs in let sub_scope_vars_redefs = - Desugared.Ast.ScopeDef.Map.mapi + D.ScopeDef.Map.mapi (fun def_key scope_def -> - let def = scope_def.Desugared.Ast.scope_def_rules in + let def = scope_def.D.scope_def_rules in let def_typ = scope_def.scope_def_typ in let is_cond = scope_def.scope_def_is_condition in match def_key with - | Desugared.Ast.ScopeDef.Var _ -> assert false (* should not happen *) - | Desugared.Ast.ScopeDef.SubScopeVar (_, sub_scope_var, var_pos) -> + | D.ScopeDef.Var _ -> assert false (* should not happen *) + | D.ScopeDef.SubScopeVar (_, sub_scope_var, var_pos) -> (* This definition redefines a variable of the correct subscope. But we have to check that this redefinition is allowed with respect to the io parameters of that subscope variable. *) (match - Mark.remove scope_def.Desugared.Ast.scope_def_io.io_input + Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> assert false (* error already raised *) | OnlyInput when RuleName.Map.is_empty def && not is_cond -> @@ -630,17 +644,19 @@ let translate_rule redefinition to a proper Scopelang term. *) let expr_def = translate_def ctx def_key def scope_def.D.scope_def_parameters - def_typ scope_def.Desugared.Ast.scope_def_io - (Desugared.Ast.ScopeDef.Map.find def_key exc_graphs) + def_typ scope_def.D.scope_def_io + (D.ScopeDef.Map.find def_key exc_graphs) ~is_cond ~is_subscope_var:true in - let subscop_real_name = + let subscop_path, subscop_real_name = SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in Ast.Definition - ( ( SubScopeVar - ( subscop_real_name, - (sub_scope_index, var_pos), + ( ( SubScopeVar { + path = subscop_path; + scope = subscop_real_name; + alias = sub_scope_index, var_pos; + var = match ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping with @@ -648,15 +664,15 @@ let translate_rule | States states -> (* When defining a sub-scope variable, we always define its first state in the sub-scope. *) - snd (List.hd states), var_pos ), + snd (List.hd states), var_pos }, var_pos ), def_typ, - scope_def.Desugared.Ast.scope_def_io, + scope_def.D.scope_def_io, Expr.unbox expr_def )) sub_scope_vars_redefs_candidates in let sub_scope_vars_redefs = - Desugared.Ast.ScopeDef.Map.values sub_scope_vars_redefs + D.ScopeDef.Map.values sub_scope_vars_redefs in sub_scope_vars_redefs @ [ @@ -668,43 +684,23 @@ let translate_rule ] | Assertion a_name -> let assertion_expr = - Desugared.Ast.AssertionName.Map.find a_name scope.scope_assertions + D.AssertionName.Map.find a_name scope.scope_assertions in (* we unbox here because assertions do not have free variables (at this point Bindlib variables are only for fuhnction parameters)*) let assertion_expr = translate_expr ctx (Expr.unbox assertion_expr) in [Ast.Assertion (Expr.unbox assertion_expr)] -(** Translates a scope *) -let translate_scope - (ctx : ctx) - (scope : Desugared.Ast.scope) - (exc_graphs : - Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t) - : untyped Ast.scope_decl = - let scope_dependencies = - Desugared.Dependency.build_scope_dependencies scope - in - Desugared.Dependency.check_for_cycle scope scope_dependencies; - let scope_ordering = - Desugared.Dependency.correct_computation_ordering scope_dependencies - in - let scope_decl_rules = - List.fold_left - (fun scope_decl_rules scope_def_key -> - let new_rules = translate_rule ctx scope exc_graphs scope_def_key in - scope_decl_rules @ new_rules) - [] scope_ordering - in +let translate_scope_interface ctx scope = let scope_sig = ScopeVar.Map.fold - (fun var (states : Desugared.Ast.var_or_states) acc -> + (fun var (states : D.var_or_states) acc -> match states with | WholeVar -> let scope_def = - Desugared.Ast.ScopeDef.Map.find - (Desugared.Ast.ScopeDef.Var (var, None)) - scope.scope_defs + D.ScopeDef.Map.find + (D.ScopeDef.Var (var, None)) + scope.D.scope_defs in let typ = scope_def.scope_def_typ in ScopeVar.Map.add @@ -720,9 +716,9 @@ let translate_scope List.fold_left (fun acc (state : StateName.t) -> let scope_def = - Desugared.Ast.ScopeDef.Map.find - (Desugared.Ast.ScopeDef.Var (var, Some state)) - scope.scope_defs + D.ScopeDef.Map.find + (D.ScopeDef.Var (var, Some state)) + scope.D.scope_defs in ScopeVar.Map.add (match ScopeVar.Map.find var ctx.scope_var_mapping with @@ -734,92 +730,143 @@ let translate_scope scope.scope_vars ScopeVar.Map.empty in let pos = Mark.get (ScopeName.get_info scope.scope_uid) in + Mark.add pos { Ast.scope_decl_name = scope.scope_uid; - Ast.scope_decl_rules; + Ast.scope_decl_rules = []; Ast.scope_sig; - Ast.scope_mark = Untyped { pos }; Ast.scope_options = scope.scope_options; } +let translate_scope + (ctx : ctx) + (exc_graphs : + Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) + (scope : D.scope) + : untyped Ast.scope_decl Mark.pos = + let scope_dependencies = + Desugared.Dependency.build_scope_dependencies scope + in + Desugared.Dependency.check_for_cycle scope scope_dependencies; + let scope_ordering = + Desugared.Dependency.correct_computation_ordering scope_dependencies + in + let scope_decl_rules = + List.fold_left + (fun scope_decl_rules scope_def_key -> + let new_rules = translate_rule ctx scope exc_graphs scope_def_key in + scope_decl_rules @ new_rules) + [] scope_ordering + in + Mark.map (fun s -> { s with Ast.scope_decl_rules }) + (translate_scope_interface ctx scope) + (** {1 API} *) let translate_program - (pgrm : Desugared.Ast.program) + (desugared : D.program) (exc_graphs : - Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t) + Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) : untyped Ast.program = (* First we give mappings to all the locations between Desugared and This involves creating a new Scopelang scope variable for every state of a Desugared variable. *) - let ctx = + let rec make_ctx desugared = + let modules = ModuleName.Map.map make_ctx desugared.D.program_modules in (* Todo: since we rename all scope vars at this point, it would be better to have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *) ScopeName.Map.fold (fun _scope scope_decl ctx -> - ScopeVar.Map.fold - (fun scope_var (states : Desugared.Ast.var_or_states) ctx -> - let var_name, var_pos = ScopeVar.get_info scope_var in - let new_var = - match states with - | Desugared.Ast.WholeVar -> - WholeVar (ScopeVar.fresh (var_name, var_pos)) - | States states -> - let var_prefix = var_name ^ "_" in - let state_var state = - ScopeVar.fresh - (Mark.map (( ^ ) var_prefix) (StateName.get_info state)) - in - States (List.map (fun state -> state, state_var state) states) - in - { - ctx with - scope_var_mapping = - ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping; - }) - scope_decl.Desugared.Ast.scope_vars ctx) - pgrm.Desugared.Ast.program_scopes + ScopeVar.Map.fold + (fun scope_var (states : D.var_or_states) ctx -> + let var_name, var_pos = ScopeVar.get_info scope_var in + let new_var = + match states with + | D.WholeVar -> + WholeVar (ScopeVar.fresh (var_name, var_pos)) + | States states -> + let var_prefix = var_name ^ "_" in + let state_var state = + ScopeVar.fresh + (Mark.map (( ^ ) var_prefix) (StateName.get_info state)) + in + States (List.map (fun state -> state, state_var state) states) + in + { + ctx with + scope_var_mapping = + ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping; + }) + scope_decl.D.scope_vars ctx) + desugared.D.program_scopes { scope_var_mapping = ScopeVar.Map.empty; var_mapping = Var.Map.empty; - decl_ctx = pgrm.program_ctx; + decl_ctx = desugared.program_ctx; + modules; } in - let ctx_scopes = - ScopeName.Map.map - (fun out_str -> - let out_struct_fields = - ScopeVar.Map.fold - (fun var fld out_map -> - let var' = - match ScopeVar.Map.find var ctx.scope_var_mapping with - | WholeVar v -> v - | States l -> snd (List.hd (List.rev l)) - in - ScopeVar.Map.add var' fld out_map) - out_str.out_struct_fields ScopeVar.Map.empty - in - { out_str with out_struct_fields }) - pgrm.Desugared.Ast.program_ctx.ctx_scopes + let ctx = make_ctx desugared in + let rec process_decl_ctx ctx decl_ctx = + let ctx_scopes = + ScopeName.Map.map + (fun out_str -> + let out_struct_fields = + ScopeVar.Map.fold + (fun var fld out_map -> + let var' = + match ScopeVar.Map.find var ctx.scope_var_mapping with + | WholeVar v -> v + | States l -> snd (List.hd (List.rev l)) + in + ScopeVar.Map.add var' fld out_map) + out_str.out_struct_fields ScopeVar.Map.empty + in + { out_str with out_struct_fields }) + decl_ctx.ctx_scopes + in + { decl_ctx with + ctx_modules = + ModuleName.Map.mapi (fun modname decl_ctx -> + let ctx = ModuleName.Map.find modname ctx.modules in + process_decl_ctx ctx decl_ctx) + decl_ctx.ctx_modules; + ctx_scopes; } in - let program_scopes = - ScopeName.Map.fold - (fun scope_name scope new_program_scopes -> - let new_program_scope = translate_scope ctx scope exc_graphs in - ScopeName.Map.add scope_name new_program_scope new_program_scopes) - pgrm.program_scopes ScopeName.Map.empty + let rec process_modules program_ctx desugared = + ModuleName.Map.mapi (fun modname m_desugared -> + let ctx = ModuleName.Map.find modname ctx.modules in + { + Ast.program_topdefs = TopdefName.Map.empty; + program_scopes = + ScopeName.Map.map + (translate_scope_interface ctx) + m_desugared.D.program_scopes; + program_ctx; + program_modules = + process_modules + (ModuleName.Map.find modname program_ctx.ctx_modules) + m_desugared; + }) + desugared.D.program_modules in + let program_ctx = process_decl_ctx ctx desugared.D.program_ctx in + let program_modules = process_modules program_ctx desugared in let program_topdefs = TopdefName.Map.mapi (fun id -> function - | Some e, ty -> Expr.unbox (translate_expr ctx e), ty - | None, (_, pos) -> - Message.raise_spanned_error pos "No definition found for %a" - TopdefName.format id) - pgrm.program_topdefs + | Some e, ty -> Expr.unbox (translate_expr ctx e), ty + | None, (_, pos) -> + Message.raise_spanned_error pos "No definition found for %a" + TopdefName.format id) + desugared.program_topdefs + in + let program_scopes = + ScopeName.Map.map (translate_scope ctx exc_graphs) desugared.D.program_scopes in { Ast.program_topdefs; - program_scopes; - program_ctx = { pgrm.program_ctx with ctx_scopes }; + Ast.program_scopes; + Ast.program_ctx; + Ast.program_modules; } diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index 90bd3973..051ba69b 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -22,8 +22,9 @@ let struc ctx (fmt : Format.formatter) (name : StructName.t) - (fields : typ StructField.Map.t) : unit = - Format.fprintf fmt "%a %a %a %a@\n@[ %a@]@\n%a" Print.keyword "struct" + (path, fields : path * typ StructField.Map.t) : unit = + Format.fprintf fmt "%a %a%a %a %a@\n@[ %a@]@\n%a" Print.keyword "struct" + Print.path path StructName.format name Print.punctuation "=" Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") @@ -37,8 +38,9 @@ let enum ctx (fmt : Format.formatter) (name : EnumName.t) - (cases : typ EnumConstructor.Map.t) : unit = - Format.fprintf fmt "%a %a %a @\n@[ %a@]" Print.keyword "enum" + (path, cases : path * typ EnumConstructor.Map.t) : unit = + Format.fprintf fmt "%a %a%a %a @\n@[ %a@]" Print.keyword "enum" + Print.path path EnumName.format name Print.punctuation "=" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") @@ -48,7 +50,7 @@ let enum (Print.typ ctx) typ)) (EnumConstructor.Map.bindings cases) -let scope ?debug ctx fmt (name, decl) = +let scope ?debug ctx fmt (name, (decl, _pos)) = Format.fprintf fmt "@[%a@ %a@ %a@ %a@ %a@]@\n@[ %a@]" Print.keyword "let" Print.keyword "scope" ScopeName.format name (Format.pp_print_list ~pp_sep:Format.pp_print_space @@ -78,7 +80,7 @@ let scope ?debug ctx fmt (name, decl) = (fun fmt e -> match Mark.remove loc with | SubScopeVar _ | ToplevelVar _ -> Print.expr () fmt e - | ScopelangScopeVar v -> ( + | ScopelangScopeVar { name = v } -> ( match Mark.remove (snd (ScopeVar.Map.find (Mark.remove v) decl.scope_sig)) @@ -92,8 +94,9 @@ let scope ?debug ctx fmt (name, decl) = | Assertion e -> Format.fprintf fmt "%a %a" Print.keyword "assert" (Print.expr ?debug ()) e - | Call (scope_name, subscope_name, _) -> - Format.fprintf fmt "%a %a%a%a%a" Print.keyword "call" + | Call ((scope_path, scope_name), subscope_name, _) -> + Format.fprintf fmt "%a %a%a%a%a%a" Print.keyword "call" + Print.path scope_path ScopeName.format scope_name Print.punctuation "[" SubScopeName.format subscope_name Print.punctuation "]")) decl.scope_decl_rules diff --git a/compiler/scopelang/print.mli b/compiler/scopelang/print.mli index 4ec2f4fe..eb1767ea 100644 --- a/compiler/scopelang/print.mli +++ b/compiler/scopelang/print.mli @@ -14,11 +14,13 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils + val scope : ?debug:bool (** [true] for debug printing *) -> Shared_ast.decl_ctx -> Format.formatter -> - Shared_ast.ScopeName.t * 'm Ast.scope_decl -> + Shared_ast.ScopeName.t * 'm Ast.scope_decl Mark.pos -> unit val program : diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 26471432..5a3462bb 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -22,6 +22,9 @@ open Catala_utils module Runtime = Runtime_ocaml.Runtime +module ModuleName = String +(* TODO: should probably be turned into an Uid once we implement module import directives; that will incur an additional resolution work on all paths though *) + module ScopeName = Uid.Gen () module TopdefName = Uid.Gen () module StructName = Uid.Gen () @@ -311,6 +314,8 @@ type except = ConflictError | EmptyError | NoValueProvided | Crash type untyped = { pos : Pos.t } [@@caml.unboxed] type typed = { pos : Pos.t; ty : typ } type 'a custom = { pos : Pos.t; custom : 'a } +type nil = | + (** Using empty markings will ensure terms can't be constructed: used for example in interfaces to ensure that they don't contain any expressions *) (** The generic type of AST markings. Using a GADT allows functions to be polymorphic in the marking, but still do transformations on types when @@ -339,19 +344,26 @@ type lit = | LDate of date | LDuration of duration +type path = ModuleName.t Mark.pos list + +(** External references are resolved to strings that point to functions or constants in the end, but we need to keep different references for typing *) +type external_ref = + | External_value of TopdefName.t + | External_scope of ScopeName.t + (** Locations are handled differently in [desugared] and [scopelang] *) type 'a glocation = | DesugaredScopeVar : - ScopeVar.t Mark.pos * StateName.t option + { name: ScopeVar.t Mark.pos; state: StateName.t option } -> < scopeVarStates : yes ; .. > glocation | ScopelangScopeVar : - ScopeVar.t Mark.pos + { name: ScopeVar.t Mark.pos } -> < scopeVarSimpl : yes ; .. > glocation | SubScopeVar : - ScopeName.t * SubScopeName.t Mark.pos * ScopeVar.t Mark.pos + { path: path; scope: ScopeName.t; alias: SubScopeName.t Mark.pos; var: ScopeVar.t Mark.pos } -> < explicitScopes : yes ; .. > glocation | ToplevelVar : - TopdefName.t Mark.pos + { path: path; name: TopdefName.t Mark.pos } -> < explicitScopes : yes ; .. > glocation type ('a, 'm) gexpr = (('a, 'm) naked_gexpr, 'm) marked @@ -392,7 +404,6 @@ and ('a, 'b, 'm) base_gexpr = -> ('a, (< .. > as 'b), 'm) base_gexpr | EArray : ('a, 'm) gexpr list -> ('a, < .. >, 'm) base_gexpr | EVar : ('a, 'm) naked_gexpr Bindlib.var -> ('a, _, 'm) base_gexpr - | EExternal : Qident.t -> ('a, < .. >, 't) base_gexpr | EAbs : { binder : (('a, 'a, 'm) base_gexpr, ('a, 'm) gexpr) Bindlib.mbinder; tys : typ list; @@ -431,11 +442,13 @@ and ('a, 'b, 'm) base_gexpr = (* Early stages *) | ELocation : 'b glocation -> ('a, (< .. > as 'b), 'm) base_gexpr | EScopeCall : { + path : path; scope : ScopeName.t; args : ('a, 'm) gexpr ScopeVar.Map.t; } -> ('a, < explicitScopes : yes ; .. >, 'm) base_gexpr | EDStructAccess : { + path : path; name_opt : StructName.t option; e : ('a, 'm) gexpr; field : Ident.t; @@ -450,6 +463,7 @@ and ('a, 'b, 'm) base_gexpr = -> ('a, < resolvedNames : yes ; .. >, 'm) base_gexpr (** Resolved struct/enums, after [desugared] *) (* Lambda-like *) + | EExternal : { path: path; name: external_ref Mark.pos} -> ('a, < explicitScopes: no ; .. >, 't) base_gexpr | EAssert : ('a, 'm) gexpr -> ('a, < assertions : yes ; .. >, 'm) base_gexpr (* Default terms *) | EDefault : { @@ -562,10 +576,11 @@ type 'e code_item_list = | Nil | Cons of 'e code_item * ('e, 'e code_item_list) binder -type struct_ctx = typ StructField.Map.t StructName.Map.t -type enum_ctx = typ EnumConstructor.Map.t EnumName.Map.t +type struct_ctx = (path * typ StructField.Map.t) StructName.Map.t +type enum_ctx = (path * typ EnumConstructor.Map.t) EnumName.Map.t -type scope_out_struct = { +type scope_info = { + in_struct_name : StructName.t; out_struct_name : StructName.t; out_struct_fields : StructField.t ScopeVar.Map.t; } @@ -575,8 +590,12 @@ type decl_ctx = { ctx_structs : struct_ctx; ctx_struct_fields : StructField.t StructName.Map.t Ident.Map.t; (** needed for disambiguation (desugared -> scope) *) - ctx_scopes : scope_out_struct ScopeName.Map.t; - ctx_modules : typ Qident.Map.t; + ctx_scopes : scope_info ScopeName.Map.t; + ctx_topdefs : typ TopdefName.Map.t; + ctx_modules : decl_ctx ModuleName.Map.t; } -type 'e program = { decl_ctx : decl_ctx; code_items : 'e code_item_list } +type 'e program = { + decl_ctx : decl_ctx; + code_items : 'e code_item_list; +} diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 59067d74..55ed0a5d 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -109,7 +109,7 @@ let subst binder vars = Bindlib.msubst binder (Array.of_list (List.map Mark.remove vars)) let evar v mark = Mark.add mark (Bindlib.box_var v) -let eexternal eref mark = Mark.add mark (Bindlib.box (EExternal eref)) +let eexternal ~path ~name mark = Mark.add mark (Bindlib.box (EExternal {path; name})) let etuple args = Box.appn args @@ fun args -> ETuple args let etupleaccess e index size = @@ -146,31 +146,31 @@ let ecustom obj targs tret mark = let elocation loc = Box.app0 @@ ELocation loc -let estruct name (fields : ('a, 't) boxed_gexpr StructField.Map.t) mark = +let estruct ~name ~(fields : ('a, 't) boxed_gexpr StructField.Map.t) mark = Mark.add mark @@ Bindlib.box_apply (fun fields -> EStruct { name; fields }) (Box.lift_struct (StructField.Map.map Box.lift fields)) -let edstructaccess e field name_opt = - Box.app1 e @@ fun e -> EDStructAccess { name_opt; e; field } +let edstructaccess ~path ~name_opt ~field ~e = + Box.app1 e @@ fun e -> EDStructAccess { path; name_opt; field; e } -let estructaccess e field name = - Box.app1 e @@ fun e -> EStructAccess { name; e; field } +let estructaccess ~name ~field ~e = + Box.app1 e @@ fun e -> EStructAccess { name; field; e } -let einj e cons name = Box.app1 e @@ fun e -> EInj { name; e; cons } +let einj ~name ~cons ~e = Box.app1 e @@ fun e -> EInj { name; cons; e } -let ematch e name cases mark = +let ematch ~name ~e ~cases mark = Mark.add mark @@ Bindlib.box_apply2 (fun e cases -> EMatch { name; e; cases }) (Box.lift e) (Box.lift_enum (EnumConstructor.Map.map Box.lift cases)) -let escopecall scope args mark = +let escopecall ~path ~scope ~args mark = Mark.add mark @@ Bindlib.box_apply - (fun args -> EScopeCall { scope; args }) + (fun args -> EScopeCall { path; scope; args }) (Box.lift_scope_vars (ScopeVar.Map.map Box.lift args)) (* - Manipulation of marks - *) @@ -272,7 +272,7 @@ let map | EOp { op; tys } -> eop op tys m | EArray args -> earray (List.map f args) m | EVar v -> evar (Var.translate v) m - | EExternal eref -> eexternal eref m + | EExternal { path; name } -> eexternal ~path ~name m | EAbs { binder; tys } -> let vars, body = Bindlib.unmbind binder in let body = f body in @@ -282,7 +282,7 @@ let map eifthenelse (f cond) (f etrue) (f efalse) m | ETuple args -> etuple (List.map f args) m | ETupleAccess { e; index; size } -> etupleaccess (f e) index size m - | EInj { e; name; cons } -> einj (f e) cons name m + | EInj { name; cons; e } -> einj ~name ~cons ~e:(f e) m | EAssert e1 -> eassert (f e1) m | EDefault { excepts; just; cons } -> edefault (List.map f excepts) (f just) (f cons) m @@ -293,16 +293,17 @@ let map | ELocation loc -> elocation loc m | EStruct { name; fields } -> let fields = StructField.Map.map f fields in - estruct name fields m - | EDStructAccess { e; field; name_opt } -> - edstructaccess (f e) field name_opt m - | EStructAccess { e; field; name } -> estructaccess (f e) field name m - | EMatch { e; name; cases } -> + estruct ~name ~fields m + | EDStructAccess { path; name_opt; field; e } -> + edstructaccess ~path ~name_opt ~field ~e:(f e) m + | EStructAccess { name; field; e } -> + estructaccess ~name ~field ~e:(f e) m + | EMatch { name; e; cases } -> let cases = EnumConstructor.Map.map f cases in - ematch (f e) name cases m - | EScopeCall { scope; args } -> - let fields = ScopeVar.Map.map f args in - escopecall scope fields m + ematch ~name ~e:(f e) ~cases m + | EScopeCall { path; scope; args } -> + let args = ScopeVar.Map.map f args in + escopecall ~path ~scope ~args m | ECustom { obj; targs; tret } -> ecustom obj targs tret m let rec map_top_down ~f e = map ~f:(map_top_down ~f) (f e) @@ -369,7 +370,7 @@ let map_gather let acc, args = lfoldmap args in acc, earray args m | EVar v -> acc, evar (Var.translate v) m - | EExternal eref -> acc, eexternal eref m + | EExternal { path; name } -> acc, eexternal ~path ~name m | EAbs { binder; tys } -> let vars, body = Bindlib.unmbind binder in let acc, body = f body in @@ -386,9 +387,9 @@ let map_gather | ETupleAccess { e; index; size } -> let acc, e = f e in acc, etupleaccess e index size m - | EInj { e; name; cons } -> + | EInj { name; cons; e } -> let acc, e = f e in - acc, einj e cons name m + acc, einj ~name ~cons ~e m | EAssert e -> let acc, e = f e in acc, eassert e m @@ -416,14 +417,14 @@ let map_gather fields (acc, StructField.Map.empty) in - acc, estruct name fields m - | EDStructAccess { e; field; name_opt } -> + acc, estruct ~name ~fields m + | EDStructAccess { path; name_opt; field; e } -> let acc, e = f e in - acc, edstructaccess e field name_opt m - | EStructAccess { e; field; name } -> + acc, edstructaccess ~path ~name_opt ~field ~e m + | EStructAccess { name; field; e } -> let acc, e = f e in - acc, estructaccess e field name m - | EMatch { e; name; cases } -> + acc, estructaccess ~name ~field ~e m + | EMatch { name; e; cases } -> let acc, e = f e in let acc, cases = EnumConstructor.Map.fold @@ -433,8 +434,8 @@ let map_gather cases (acc, EnumConstructor.Map.empty) in - acc, ematch e name cases m - | EScopeCall { scope; args } -> + acc, ematch ~name ~e ~cases m + | EScopeCall { path; scope; args } -> let acc, args = ScopeVar.Map.fold (fun var e (acc, args) -> @@ -442,7 +443,7 @@ let map_gather join acc acc1, ScopeVar.Map.add var e args) args (acc, ScopeVar.Map.empty) in - acc, escopecall scope args m + acc, escopecall ~path ~scope ~args m | ECustom { obj; targs; tret } -> acc, ecustom obj targs tret m (* - *) @@ -515,25 +516,31 @@ let compare_lit (l1 : lit) (l2 : lit) = | LDuration _, _ -> . | _, LDuration _ -> . +let compare_path = + List.compare (Mark.compare ModuleName.compare) + let compare_location (type a) (x : a glocation Mark.pos) (y : a glocation Mark.pos) = match Mark.remove x, Mark.remove y with - | DesugaredScopeVar (vx, None), DesugaredScopeVar (vy, None) - | DesugaredScopeVar (vx, Some _), DesugaredScopeVar (vy, None) - | DesugaredScopeVar (vx, None), DesugaredScopeVar (vy, Some _) -> + | DesugaredScopeVar { name = vx; state = None}, DesugaredScopeVar { name = vy; state = None} + | DesugaredScopeVar { name = vx; state = Some _}, DesugaredScopeVar { name = vy; state = None} + | DesugaredScopeVar { name = vx; state = None}, DesugaredScopeVar { name = vy; state = Some _} -> ScopeVar.compare (Mark.remove vx) (Mark.remove vy) - | DesugaredScopeVar ((x, _), Some sx), DesugaredScopeVar ((y, _), Some sy) -> + | DesugaredScopeVar {name = (x, _); state = Some sx}, DesugaredScopeVar {name = (y, _); state = Some sy} -> let cmp = ScopeVar.compare x y in if cmp = 0 then StateName.compare sx sy else cmp - | ScopelangScopeVar (vx, _), ScopelangScopeVar (vy, _) -> + | ScopelangScopeVar { name = (vx, _) }, ScopelangScopeVar { name = (vy, _) } -> ScopeVar.compare vx vy - | ( SubScopeVar (_, (xsubindex, _), (xsubvar, _)), - SubScopeVar (_, (ysubindex, _), (ysubvar, _)) ) -> + | ( SubScopeVar { alias = (xsubindex, _); var = (xsubvar, _); _}, + SubScopeVar { alias = (ysubindex, _); var = (ysubvar, _); _} ) -> let c = SubScopeName.compare xsubindex ysubindex in if c = 0 then ScopeVar.compare xsubvar ysubvar else c - | ToplevelVar (vx, _), ToplevelVar (vy, _) -> TopdefName.compare vx vy + | ToplevelVar { path = px; name = (vx, _) }, ToplevelVar { path = py; name = (vy, _) } -> + (match compare_path px py with + | 0 -> TopdefName.compare vx vy + | n -> n) | DesugaredScopeVar _, _ -> -1 | _, DesugaredScopeVar _ -> 1 | ScopelangScopeVar _, _ -> -1 @@ -543,21 +550,33 @@ let compare_location | ToplevelVar _, _ -> . | _, ToplevelVar _ -> . +let equal_path = List.equal (Mark.equal ModuleName.equal) let equal_location a b = compare_location a b = 0 let equal_except ex1 ex2 = ex1 = ex2 let compare_except ex1 ex2 = Stdlib.compare ex1 ex2 +let equal_external_ref ref1 ref2 = match ref1, ref2 with + | External_value v1, External_value v2 -> TopdefName.equal v1 v2 + | External_scope s1, External_scope s2 -> ScopeName.equal s1 s2 + | (External_value _ | External_scope _), _ -> false +let compare_external_ref ref1 ref2 = match ref1, ref2 with + | External_value v1, External_value v2 -> TopdefName.compare v1 v2 + | External_scope s1, External_scope s2 -> ScopeName.compare s1 s2 + | External_value _, _ -> -1 + | _, External_value _ -> 1 + | External_scope _, _ -> . + | _, External_scope _ -> . (* weird indentation; see https://github.com/ocaml-ppx/ocamlformat/issues/2143 *) let rec equal_list : 'a. ('a, 't) gexpr list -> ('a, 't) gexpr list -> bool = - fun es1 es2 -> - try List.for_all2 equal es1 es2 with Invalid_argument _ -> false + fun es1 es2 -> List.equal equal es1 es2 and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = fun e1 e2 -> match Mark.remove e1, Mark.remove e2 with | EVar v1, EVar v2 -> Bindlib.eq_vars v1 v2 - | EExternal eref1, EExternal eref2 -> Qident.equal eref1 eref2 + | EExternal { path = p1; name = n1 }, EExternal { path = p2; name = n2 } -> + Mark.equal equal_external_ref n1 n2 && equal_path p1 p2 | ETuple es1, ETuple es2 -> equal_list es1 es2 | ( ETupleAccess { e = e1; index = id1; size = s1 }, ETupleAccess { e = e2; index = id2; size = s2 } ) -> @@ -588,23 +607,25 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = | ( EStruct { name = s1; fields = fields1 }, EStruct { name = s2; fields = fields2 } ) -> StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2 - | ( EDStructAccess { e = e1; field = f1; name_opt = s1 }, - EDStructAccess { e = e2; field = f2; name_opt = s2 } ) -> - Option.equal StructName.equal s1 s2 && Ident.equal f1 f2 && equal e1 e2 + | ( EDStructAccess { e = e1; field = f1; name_opt = s1; path = p1 }, + EDStructAccess { e = e2; field = f2; name_opt = s2; path = p2 } ) -> + Option.equal StructName.equal s1 s2 && equal_path p1 p2 && Ident.equal f1 f2 && equal e1 e2 | ( EStructAccess { e = e1; field = f1; name = s1 }, EStructAccess { e = e2; field = f2; name = s2 } ) -> StructName.equal s1 s2 && StructField.equal f1 f2 && equal e1 e2 - | EInj { e = e1; cons = c1; name = n1 }, EInj { e = e2; cons = c2; name = n2 } - -> + | EInj { e = e1; cons = c1; name = n1 }, + EInj { e = e2; cons = c2; name = n2 } -> EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2 | ( EMatch { e = e1; name = n1; cases = cases1 }, EMatch { e = e2; name = n2; cases = cases2 } ) -> EnumName.equal n1 n2 && equal e1 e2 && EnumConstructor.Map.equal equal cases1 cases2 - | ( EScopeCall { scope = s1; args = fields1 }, - EScopeCall { scope = s2; args = fields2 } ) -> - ScopeName.equal s1 s2 && ScopeVar.Map.equal equal fields1 fields2 + | ( EScopeCall { path = p1; scope = s1; args = fields1 }, + EScopeCall { path = p2; scope = s2; args = fields2 } ) -> + ScopeName.equal s1 s2 && + equal_path p1 p2 && + ScopeVar.Map.equal equal fields1 fields2 | ( ECustom { obj = obj1; targs = targs1; tret = tret1 }, ECustom { obj = obj2; targs = targs2; tret = tret2 } ) -> Type.equal_list targs1 targs2 && Type.equal tret1 tret2 && obj1 == obj2 @@ -635,8 +656,8 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = List.compare compare a1 a2 | EVar v1, EVar v2 -> Bindlib.compare_vars v1 v2 - | EExternal eref1, EExternal eref2 -> - Qident.compare eref1 eref2 + | EExternal { path = p1; name = n1 }, EExternal { path = p2; name = n2 } -> + compare_path p1 p2 @@< fun () -> Mark.compare compare_external_ref n1 n2 | EAbs {binder=binder1; tys=typs1}, EAbs {binder=binder2; tys=typs2} -> List.compare Type.compare typs1 typs2 @@< fun () -> @@ -649,27 +670,29 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = compare e1 e2 | ELocation l1, ELocation l2 -> compare_location (Mark.add Pos.no_pos l1) (Mark.add Pos.no_pos l2) - | EStruct {name=name1; fields=field_map1}, - EStruct {name=name2; fields=field_map2} -> + | EStruct {name=name1; fields=field_map1 }, + EStruct {name=name2; fields=field_map2 } -> StructName.compare name1 name2 @@< fun () -> StructField.Map.compare compare field_map1 field_map2 - | EDStructAccess {e=e1; field=field_name1; name_opt=struct_name1}, - EDStructAccess {e=e2; field=field_name2; name_opt=struct_name2} -> + | EDStructAccess {e=e1; field=field_name1; name_opt=struct_name1; path=p1}, + EDStructAccess {e=e2; field=field_name2; name_opt=struct_name2; path=p2} -> compare e1 e2 @@< fun () -> + compare_path p1 p2 @@< fun () -> Ident.compare field_name1 field_name2 @@< fun () -> Option.compare StructName.compare struct_name1 struct_name2 - | EStructAccess {e=e1; field=field_name1; name=struct_name1}, - EStructAccess {e=e2; field=field_name2; name=struct_name2} -> + | EStructAccess {e=e1; field=field_name1; name=struct_name1 }, + EStructAccess {e=e2; field=field_name2; name=struct_name2 } -> compare e1 e2 @@< fun () -> StructField.compare field_name1 field_name2 @@< fun () -> StructName.compare struct_name1 struct_name2 - | EMatch {e=e1; name=name1; cases=emap1}, - EMatch {e=e2; name=name2; cases=emap2} -> + | EMatch {e=e1; name=name1; cases=emap1 }, + EMatch {e=e2; name=name2; cases=emap2 } -> EnumName.compare name1 name2 @@< fun () -> compare e1 e2 @@< fun () -> EnumConstructor.Map.compare compare emap1 emap2 - | EScopeCall {scope=name1; args=field_map1}, - EScopeCall {scope=name2; args=field_map2} -> + | EScopeCall {path = p1; scope=name1; args=field_map1}, + EScopeCall {path = p2; scope=name2; args=field_map2} -> + compare_path p1 p2 @@< fun () -> ScopeName.compare name1 name2 @@< fun () -> ScopeVar.Map.compare compare field_map1 field_map2 | ETuple es1, ETuple es2 -> @@ -679,8 +702,8 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = Int.compare s1 s2 @@< fun () -> Int.compare n1 n2 @@< fun () -> compare e1 e2 - | EInj {e=e1; name=name1; cons=cons1}, - EInj {e=e2; name=name2; cons=cons2} -> + | EInj {e=e1; name=name1; cons=cons1 }, + EInj {e=e2; name=name2; cons=cons2 } -> EnumName.compare name1 name2 @@< fun () -> EnumConstructor.compare cons1 cons2 @@< fun () -> compare e1 e2 diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index 0f90eeb0..df34350d 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -36,7 +36,7 @@ val rebox : ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr (** Rebuild the whole term, re-binding all variables and exposing free variables *) val evar : ('a, 'm) gexpr Var.t -> 'm mark -> ('a, 'm) boxed_gexpr -val eexternal : Qident.t -> 'm mark -> ('a any, 'm) boxed_gexpr +val eexternal : path:path -> name:external_ref Mark.pos -> 'm mark -> (< explicitScopes: no; .. >, 'm) boxed_gexpr val bind : ('a, 'm) gexpr Var.t array -> @@ -108,42 +108,44 @@ val eraise : except -> 'm mark -> (< exceptions : yes ; .. >, 'm) boxed_gexpr val elocation : 'a glocation -> 'm mark -> ((< .. > as 'a), 'm) boxed_gexpr val estruct : - StructName.t -> - ('a, 'm) boxed_gexpr StructField.Map.t -> + name: StructName.t -> + fields: ('a, 'm) boxed_gexpr StructField.Map.t -> 'm mark -> ('a any, 'm) boxed_gexpr val edstructaccess : - ('a, 'm) boxed_gexpr -> - Ident.t -> - StructName.t option -> + path: path -> + name_opt: StructName.t option -> + field: Ident.t -> + e: ('a, 'm) boxed_gexpr -> 'm mark -> ((< syntacticNames : yes ; .. > as 'a), 'm) boxed_gexpr val estructaccess : - ('a, 'm) boxed_gexpr -> - StructField.t -> - StructName.t -> + name: StructName.t -> + field: StructField.t -> + e: ('a, 'm) boxed_gexpr -> 'm mark -> ((< resolvedNames : yes ; .. > as 'a), 'm) boxed_gexpr val einj : - ('a, 'm) boxed_gexpr -> - EnumConstructor.t -> - EnumName.t -> + name: EnumName.t -> + cons: EnumConstructor.t -> + e: ('a, 'm) boxed_gexpr -> 'm mark -> ('a any, 'm) boxed_gexpr val ematch : - ('a, 'm) boxed_gexpr -> - EnumName.t -> - ('a, 'm) boxed_gexpr EnumConstructor.Map.t -> + name: EnumName.t -> + e: ('a, 'm) boxed_gexpr -> + cases: ('a, 'm) boxed_gexpr EnumConstructor.Map.t -> 'm mark -> ('a any, 'm) boxed_gexpr val escopecall : - ScopeName.t -> - ('a, 'm) boxed_gexpr ScopeVar.Map.t -> + path:path -> + scope:ScopeName.t -> + args:('a, 'm) boxed_gexpr ScopeVar.Map.t -> 'm mark -> ((< explicitScopes : yes ; .. > as 'a), 'm) boxed_gexpr @@ -382,6 +384,8 @@ val format : Format.formatter -> ('a, 'm) gexpr -> unit val equal_lit : lit -> lit -> bool val compare_lit : lit -> lit -> int +val equal_path : path -> path -> bool +val compare_path : path -> path -> int val equal_location : 'a glocation Mark.pos -> 'a glocation Mark.pos -> bool val compare_location : 'a glocation Mark.pos -> 'a glocation Mark.pos -> int val equal_except : except -> except -> bool diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 102c32b8..5a9ff702 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -448,6 +448,7 @@ let rec runtime_to_val : m ) | TStruct name -> StructName.Map.find name ctx.ctx_structs + |> snd |> StructField.Map.to_seq |> Seq.map2 (fun o (fld, ty) -> fld, runtime_to_val eval_expr ctx m ty o) @@ -458,9 +459,10 @@ let rec runtime_to_val : (* we only use non-constant constructors of arity 1, which allows us to always use the tag directly (ordered as declared in the constr map), and the field 0 *) + let _path, cons_map = EnumName.Map.find name ctx.ctx_enums in let cons, ty = List.nth - (EnumConstructor.Map.bindings (EnumName.Map.find name ctx.ctx_enums)) + (EnumConstructor.Map.bindings cons_map) (Obj.tag o - Obj.first_non_constant_constructor_tag) in let e = runtime_to_val eval_expr ctx m ty (Obj.field o 0) in @@ -495,7 +497,7 @@ and val_to_runtime : List.map2 (val_to_runtime eval_expr ctx) ts es |> Array.of_list |> Obj.repr | TStruct name1, EStruct { name; fields } -> assert (StructName.equal name name1); - let fld_tys = StructName.Map.find name ctx.ctx_structs in + let _path, fld_tys = StructName.Map.find name ctx.ctx_structs in Seq.map2 (fun (_, ty) (_, v) -> val_to_runtime eval_expr ctx ty v) (StructField.Map.to_seq fld_tys) @@ -504,6 +506,7 @@ and val_to_runtime : |> Obj.repr | TEnum name1, EInj { name; cons; e } -> assert (EnumName.equal name name1); + let _path, cons_map = EnumName.Map.find name ctx.ctx_enums in let rec find_tag n = function | [] -> assert false | (c, ty) :: _ when EnumConstructor.equal c cons -> n, ty @@ -511,7 +514,7 @@ and val_to_runtime : in let tag, ty = find_tag Obj.first_non_constant_constructor_tag - (EnumConstructor.Map.bindings (EnumName.Map.find name ctx.ctx_enums)) + (EnumConstructor.Map.bindings cons_map) in let o = Obj.with_tag tag (Obj.repr (Some ())) in Obj.set_field o 0 (val_to_runtime eval_expr ctx ty e); @@ -546,13 +549,30 @@ let rec evaluate_expr : Message.raise_spanned_error pos "free variable found at evaluation (should not happen if term was \ well-typed)" - | EExternal qid -> ( - match Qident.Map.find_opt qid ctx.ctx_modules with - | None -> - Message.raise_spanned_error pos "Reference to %a could not be resolved" - Qident.format qid - | Some ty -> - let o = Runtime.lookup_value qid in + | EExternal { path; name } -> ( + let ty = + try + let ctx = Program.module_ctx ctx path in + match Mark.remove name with + | External_value name -> + TopdefName.Map.find name ctx.ctx_topdefs + | External_scope name -> + let scope_info = ScopeName.Map.find name ctx.ctx_scopes in + TArrow ([TStruct scope_info.in_struct_name, pos], + (TStruct scope_info.out_struct_name, pos)), + pos + with Not_found -> + Message.raise_spanned_error pos "Reference to %a%a could not be resolved" + Print.path path Print.external_ref name + in + let runtime_path = + List.map Mark.remove path, + match Mark.remove name with + | External_value name -> Mark.remove (TopdefName.get_info name) + | External_scope name -> Mark.remove (ScopeName.get_info name) + (* we have the guarantee that the two cases won't collide because they have different capitalisation rules inherited from the input *) + in + let o = Runtime.lookup_value runtime_path in runtime_to_val evaluate_expr ctx m ty o) | EApp { f = e1; args } -> ( let e1 = evaluate_expr ctx e1 in @@ -792,14 +812,14 @@ let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list the types of the scope arguments. For [context] arguments, we can provide an empty thunked term. But for [input] arguments of another type, we cannot provide anything so we have to fail. *) - let taus = StructName.Map.find s_in ctx.ctx_structs in + let _path, taus = StructName.Map.find s_in ctx.ctx_structs in let application_term = StructField.Map.map (fun ty -> match Mark.remove ty with | TOption _ -> - (Expr.einj (Expr.elit LUnit mark_e) Expr.none_constr - Expr.option_enum mark_e + (Expr.einj ~e:(Expr.elit LUnit mark_e) ~cons:Expr.none_constr + ~name:Expr.option_enum mark_e : (_, _) boxed_gexpr) | _ -> Message.raise_spanned_error (Mark.get ty) @@ -812,7 +832,7 @@ let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list in let to_interpret = Expr.make_app (Expr.box e) - [Expr.estruct s_in application_term mark_e] + [Expr.estruct ~name:s_in ~fields:application_term mark_e] (Expr.pos e) in match Mark.remove (evaluate_expr ctx (Expr.unbox to_interpret)) with @@ -842,7 +862,7 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list the types of the scope arguments. For [context] arguments, we can provide an empty thunked term. But for [input] arguments of another type, we cannot provide anything so we have to fail. *) - let taus = StructName.Map.find s_in ctx.ctx_structs in + let _path, taus = StructName.Map.find s_in ctx.ctx_structs in let application_term = StructField.Map.map (fun ty -> @@ -863,7 +883,7 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list in let to_interpret = Expr.make_app (Expr.box e) - [Expr.estruct s_in application_term mark_e] + [Expr.estruct ~name:s_in ~fields:application_term mark_e] (Expr.pos e) in match Mark.remove (evaluate_expr ctx (Expr.unbox to_interpret)) with diff --git a/compiler/shared_ast/optimizations.ml b/compiler/shared_ast/optimizations.ml index 5f4808f2..e06c1eb6 100644 --- a/compiler/shared_ast/optimizations.ml +++ b/compiler/shared_ast/optimizations.ml @@ -178,7 +178,7 @@ let rec optimize_expr : when false (* TODO: this case is buggy because of the box/unbox manipulation, it should be fixed before removing this [false] value*) - && n1 = n2 + && EnumName.equal n1 n2 && all_match_cases_map_to_same_constructor cases1 n1 -> (* iota-reduction when the matched expression is itself a match of the same enum mapping all constructors to themselves *) @@ -211,7 +211,7 @@ let rec optimize_expr : (* beta reduction when variables not used. *) Mark.remove (Bindlib.msubst binder (List.map fst args |> Array.of_list)) | EStructAccess { name; field; e = EStruct { name = name1; fields }, _ } - when name = name1 -> + when StructName.equal name name1 -> Mark.remove (StructField.Map.find field fields) | EDefault { excepts; just; cons } -> ( (* TODO: mechanically prove each of these optimizations correct *) @@ -353,9 +353,9 @@ let test_iota_reduction_1 () = let consC = EnumConstructor.fresh ("C", Pos.no_pos) in let consD = EnumConstructor.fresh ("D", Pos.no_pos) in let nomark = Untyped { pos = Pos.no_pos } in - let injA = Expr.einj (Expr.evar x nomark) consA enumT nomark in - let injC = Expr.einj (Expr.evar x nomark) consC enumT nomark in - let injD = Expr.einj (Expr.evar x nomark) consD enumT nomark in + let injA = Expr.einj ~e:(Expr.evar x nomark) ~cons:consA ~name:enumT nomark in + let injC = Expr.einj ~e:(Expr.evar x nomark) ~cons:consC ~name:enumT nomark in + let injD = Expr.einj ~e:(Expr.evar x nomark) ~cons:consD ~name:enumT nomark in let cases : ('a, 't) boxed_gexpr EnumConstructor.Map.t = EnumConstructor.Map.of_list [ @@ -363,7 +363,7 @@ let test_iota_reduction_1 () = consB, Expr.eabs (Expr.bind [| x |] injD) [TAny, Pos.no_pos] nomark; ] in - let matchA = Expr.ematch injA enumT cases nomark in + let matchA = Expr.ematch ~e:injA ~name:enumT ~cases nomark in Alcotest.(check string) "same string" "before=match (A x)\n\ @@ -397,10 +397,10 @@ let test_iota_reduction_2 () = let num n = Expr.elit (LInt (Runtime.integer_of_int n)) nomark in - let injAe e = Expr.einj e consA enumT nomark in - let injBe e = Expr.einj e consB enumT nomark in - let injCe e = Expr.einj e consC enumT nomark in - let injDe e = Expr.einj e consD enumT nomark in + let injAe e = Expr.einj ~e ~cons:consA ~name:enumT nomark in + let injBe e = Expr.einj ~e ~cons:consB ~name:enumT nomark in + let injCe e = Expr.einj ~e ~cons:consC ~name:enumT nomark in + let injDe e = Expr.einj ~e ~cons:consD ~name:enumT nomark in (* let injA x = injAe (Expr.evar x nomark) in *) let injB x = injBe (Expr.evar x nomark) in @@ -409,14 +409,14 @@ let test_iota_reduction_2 () = let matchA = Expr.ematch - (Expr.ematch (num 1) enumT - (cases_of_list + ~e:(Expr.ematch ~e:(num 1) ~name:enumT + ~cases:(cases_of_list [ (consB, fun x -> injBe (injB x)); (consA, fun _x -> injAe (num 20)); ]) nomark) - enumT - (cases_of_list [consA, injC; consB, injD]) + ~name:enumT + ~cases:(cases_of_list [consA, injC; consB, injD]) nomark in Alcotest.(check string) diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index d199aa26..c2502c69 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -70,14 +70,25 @@ let tlit (fmt : Format.formatter) (l : typ_lit) : unit = | TDuration -> "duration" | TDate -> "date") +let module_name ppf m = Format.fprintf ppf "@{%a@}" ModuleName.format m + +let path ppf p = + Format.pp_print_list ~pp_sep:(fun _ () -> ()) + (fun ppf m -> + Format.fprintf ppf "%a@{.@}" + module_name (Mark.remove m)) + ppf p + let location (type a) (fmt : Format.formatter) (l : a glocation) : unit = match l with - | DesugaredScopeVar (v, _st) -> ScopeVar.format fmt (Mark.remove v) - | ScopelangScopeVar v -> ScopeVar.format fmt (Mark.remove v) - | SubScopeVar (_, subindex, subvar) -> + | DesugaredScopeVar { name; _ } -> ScopeVar.format fmt (Mark.remove name) + | ScopelangScopeVar { name; _ } -> ScopeVar.format fmt (Mark.remove name) + | SubScopeVar { alias=subindex; var=subvar; _ } -> Format.fprintf fmt "%a.%a" SubScopeName.format (Mark.remove subindex) ScopeVar.format (Mark.remove subvar) - | ToplevelVar v -> TopdefName.format fmt (Mark.remove v) + | ToplevelVar { path=p; name } -> + path fmt p; + TopdefName.format fmt (Mark.remove name) let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit = Format.fprintf fmt "@{%a@}" EnumConstructor.format c @@ -85,6 +96,19 @@ let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit = let struct_field (fmt : Format.formatter) (c : StructField.t) : unit = Format.fprintf fmt "@{%a@}" StructField.format c +let external_ref fmt er = + match Mark.remove er with + | External_value v -> TopdefName.format fmt v + | External_scope s -> ScopeName.format fmt s + +let rec module_ctx ctx = function + | [] -> ctx + | (modname, mpos) :: path -> + match ModuleName.Map.find_opt modname ctx.ctx_modules with + | None -> + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname + | Some ctx -> module_ctx ctx path + let rec typ_gen (ctx : decl_ctx option) ~(colors : Ocolor_types.color4 list) @@ -113,12 +137,16 @@ let rec typ_gen pp_color_string (List.hd colors) fmt ")" | TStruct s -> ( match ctx with - | None -> StructName.format fmt s + | None -> + StructName.format fmt s | Some ctx -> - let fields = StructName.Map.find s ctx.ctx_structs in - if StructField.Map.is_empty fields then StructName.format fmt s + let p, fields = StructName.Map.find s ctx.ctx_structs in + if StructField.Map.is_empty fields then + (path fmt p; StructName.format fmt s) else - Format.fprintf fmt "@[%a %a@,%a@;<0 -2>%a@]" StructName.format s + Format.fprintf fmt "@[%a%a %a@,%a@;<0 -2>%a@]" + path p + StructName.format s (pp_color_string (List.hd colors)) "{" (StructField.Map.format_bindings @@ -137,13 +165,14 @@ let rec typ_gen match ctx with | None -> Format.fprintf fmt "@[%a@]" EnumName.format e | Some ctx -> - Format.fprintf fmt "@[%a%a%a%a@]" EnumName.format e punctuation "[" + let p, def = EnumName.Map.find e ctx.ctx_enums in + Format.fprintf fmt "@[%a%a%a%a%a@]" path p EnumName.format e punctuation "[" (EnumConstructor.Map.format_bindings ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " punctuation "|") (fun fmt pp_case mty -> Format.fprintf fmt "%t%a@ %a" pp_case punctuation ":" (typ ~colors) mty)) - (EnumName.Map.find e ctx.ctx_enums) + def punctuation "]") | TOption t -> Format.fprintf fmt "@[%a@ %a@]" base_type "eoption" (typ ~colors) t @@ -499,7 +528,9 @@ module ExprGen (C : EXPR_PARAM) = struct else match Mark.remove e with | EVar v -> var fmt v - | EExternal eref -> Qident.format fmt eref + | EExternal {path=p; name} -> + path fmt p; + external_ref fmt name | ETuple es -> Format.fprintf fmt "@[%a%a%a@]" (pp_color_string (List.hd colors)) @@ -696,8 +727,9 @@ module ExprGen (C : EXPR_PARAM) = struct Format.fprintf fmt "@[%a %t@ %a@ %a@]" punctuation "|" pp_cons_name punctuation "→" (rhs exprc) e)) cases - | EScopeCall { scope; args } -> + | EScopeCall { path = scope_path; scope; args } -> Format.pp_open_hovbox fmt 2; + path fmt scope_path; ScopeName.format fmt scope; Format.pp_print_space fmt (); keyword fmt "of"; @@ -839,8 +871,8 @@ let enum decl_ctx fmt (pp_name : Format.formatter -> unit) - (c : typ EnumConstructor.Map.t) = - Format.fprintf fmt "@[%a %t %a@ %a@]" keyword "type" pp_name punctuation + (p, c : path * typ EnumConstructor.Map.t) = + Format.fprintf fmt "@[%a %a%t %a@ %a@]" keyword "type" path p pp_name punctuation "=" (EnumConstructor.Map.format_bindings ~pp_sep:(fun _ _ -> ()) @@ -856,9 +888,9 @@ let struct_ decl_ctx fmt (pp_name : Format.formatter -> unit) - (c : typ StructField.Map.t) = - Format.fprintf fmt "@[@[@[%a %t %a@;%a@]@;%a@]%a@]@;" keyword - "type" pp_name punctuation "=" punctuation "{" + (p, c : path * typ StructField.Map.t) = + Format.fprintf fmt "@[@[@[%a %a%t %a@;%a@]@;%a@]%a@]@;" keyword + "type" path p pp_name punctuation "=" punctuation "{" (StructField.Map.format_bindings ~pp_sep:(fun _ _ -> ()) (fun fmt pp_n ty -> diff --git a/compiler/shared_ast/print.mli b/compiler/shared_ast/print.mli index 595e1794..cbab2799 100644 --- a/compiler/shared_ast/print.mli +++ b/compiler/shared_ast/print.mli @@ -42,7 +42,10 @@ val operator_to_string : 'a operator -> string val uid_list : Format.formatter -> Uid.MarkedString.info list -> unit val enum_constructor : Format.formatter -> EnumConstructor.t -> unit val tlit : Format.formatter -> typ_lit -> unit +val module_name : Format.formatter -> ModuleName.t -> unit +val path : Format.formatter -> ModuleName.t Mark.pos list -> unit val location : Format.formatter -> 'a glocation -> unit +val external_ref : Format.formatter -> external_ref Mark.pos -> unit val typ : decl_ctx -> Format.formatter -> typ -> unit val lit : Format.formatter -> lit -> unit val operator : ?debug:bool -> Format.formatter -> 'a operator -> unit diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index c05f6798..f78893c3 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -15,6 +15,7 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils open Definitions let map_exprs ~f ~varf { code_items; decl_ctx } = @@ -34,9 +35,18 @@ let empty_ctx = ctx_structs = StructName.Map.empty; ctx_struct_fields = Ident.Map.empty; ctx_scopes = ScopeName.Map.empty; - ctx_modules = Qident.Map.empty; + ctx_topdefs = TopdefName.Map.empty; + ctx_modules = ModuleName.Map.empty; } +let rec module_ctx ctx = function + | [] -> ctx + | (modname, mpos) :: path -> + match ModuleName.Map.find_opt modname ctx.ctx_modules with + | None -> + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname + | Some ctx -> module_ctx ctx path + let get_scope_body { code_items; _ } scope = match Scope.fold_left ~init:None diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index d1c8b704..5b5252cb 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -15,12 +15,16 @@ License for the specific language governing permissions and limitations under the License. *) +open Catala_utils open Definitions (** {2 Program declaration context helpers} *) val empty_ctx : decl_ctx +val module_ctx : decl_ctx -> ModuleName.t Mark.pos list -> decl_ctx +(** Follows a path to get the corresponding context for type and value declarations. Errors out if the module is not found *) + (** {2 Transformations} *) val map_exprs : @@ -47,3 +51,4 @@ val to_expr : ((_ any, _) gexpr as 'e) program -> ScopeName.t -> 'e boxed val equal : (('a any, _) gexpr as 'e) program -> (('a any, _) gexpr as 'e) program -> bool +(** Warning / todo: only compares program scopes at the moment *) diff --git a/compiler/shared_ast/scope.mli b/compiler/shared_ast/scope.mli index 8321f2b9..143b3d54 100644 --- a/compiler/shared_ast/scope.mli +++ b/compiler/shared_ast/scope.mli @@ -1,5 +1,5 @@ (* This file is part of the Catala compiler, a specification language for tax - and social benefits computation rules. Copyright (C) 2020-2022 Inria, +< and social benefits computation rules. Copyright (C) 2020-2022 Inria, contributor: Denis Merigoux , Alain Delaët-Tixeuil , Louis Gesbert diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 0da04c7a..b2cda58d 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -126,12 +126,12 @@ let rec format_typ (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") (fun fmt t -> - Format.fprintf fmt "%a" (format_typ ~colors:(List.tl colors)) t)) + format_typ fmt ~colors:(List.tl colors) t)) ts (pp_color_string (List.hd colors)) ")" - | TStruct s -> Format.fprintf fmt "%a" A.StructName.format s - | TEnum e -> Format.fprintf fmt "%a" A.EnumName.format e + | TStruct s -> Print.path fmt (fst (A.StructName.Map.find s ctx.A.ctx_structs)); A.StructName.format fmt s + | TEnum e -> Print.path fmt (fst (A.EnumName.Map.find e ctx.A.ctx_enums)); A.EnumName.format fmt e | TOption t -> Format.fprintf fmt "@[option %a@]" (format_typ_with_parens ~colors:(List.tl colors)) @@ -313,24 +313,26 @@ module Env = struct scope_vars : A.typ A.ScopeVar.Map.t; scopes : A.typ A.ScopeVar.Map.t A.ScopeName.Map.t; toplevel_vars : A.typ A.TopdefName.Map.t; + modules : 'e t A.ModuleName.Map.t; } - let empty (decl_ctx : A.decl_ctx) = + let rec empty (decl_ctx : A.decl_ctx) = (* We fill the environment initially with the structs and enums declarations *) { structs = A.StructName.Map.map - (A.StructField.Map.map ast_to_typ) + (fun (_path, ty) -> A.StructField.Map.map ast_to_typ ty) decl_ctx.ctx_structs; enums = A.EnumName.Map.map - (A.EnumConstructor.Map.map ast_to_typ) + (fun (_path, ty) -> A.EnumConstructor.Map.map ast_to_typ ty) decl_ctx.ctx_enums; vars = Var.Map.empty; scope_vars = A.ScopeVar.Map.empty; scopes = A.ScopeName.Map.empty; toplevel_vars = A.TopdefName.Map.empty; + modules = A.ModuleName.Map.map empty decl_ctx.A.ctx_modules; } let get t v = Var.Map.find_opt v t.vars @@ -341,6 +343,14 @@ module Env = struct Option.bind (A.ScopeName.Map.find_opt scope t.scopes) (fun vmap -> A.ScopeVar.Map.find_opt var vmap) + let rec module_env path env = + match path with + | [] -> env + | (modname, mpos) :: path -> + match A.ModuleName.Map.find_opt modname env.modules with + | None -> Message.raise_spanned_error mpos "Module %a not found" A.ModuleName.format modname + | Some env -> module_env path env + let add v tau t = { t with vars = Var.Map.add v tau t.vars } let add_var v typ t = add v (ast_to_typ typ) t @@ -353,6 +363,10 @@ module Env = struct let add_toplevel_var v typ t = { t with toplevel_vars = A.TopdefName.Map.add v typ t.toplevel_vars } + let add_module modname ~module_env t = + Message.emit_debug "ADD MODULE %a" A.ModuleName.format modname; + { t with modules = A.ModuleName.Map.add modname module_env t.modules } + let open_scope scope_name t = let scope_vars = A.ScopeVar.Map.union @@ -414,11 +428,14 @@ and typecheck_expr_top_down : | A.ELocation loc -> let ty_opt = match loc with - | DesugaredScopeVar (v, _) | ScopelangScopeVar v -> - Env.get_scope_var env (Mark.remove v) - | SubScopeVar (scope, _, v) -> - Env.get_subscope_out_var env scope (Mark.remove v) - | ToplevelVar v -> Env.get_toplevel_var env (Mark.remove v) + | DesugaredScopeVar {name;_} | ScopelangScopeVar {name} -> + Env.get_scope_var env (Mark.remove name) + | SubScopeVar {path; scope; var; _} -> + let env = Env.module_env path env in + Env.get_subscope_out_var env scope (Mark.remove var) + | ToplevelVar {path; name} -> + let env = Env.module_env path env in + Env.get_toplevel_var env (Mark.remove name) in let ty = match ty_opt with @@ -430,8 +447,12 @@ and typecheck_expr_top_down : Expr.elocation loc (mark_with_tau_and_unify (ast_to_typ ty)) | A.EStruct { name; fields } -> let mark = ty_mark (TStruct name) in - let str_ast = A.StructName.Map.find name ctx.A.ctx_structs in - let str = A.StructName.Map.find name env.structs in + let _path, str_ast = + A.StructName.Map.find name ctx.A.ctx_structs + in + let str = + A.StructName.Map.find name env.structs + in let _check_fields : unit = let missing_fields, extra_fields = A.StructField.Map.fold @@ -463,15 +484,15 @@ and typecheck_expr_top_down : "Mismatching field definitions for structure %a" A.StructName.format name in - let fields' = + let fields = A.StructField.Map.mapi (fun f_name f_e -> let f_ty = A.StructField.Map.find f_name str in typecheck_expr_top_down ~leave_unresolved ctx env f_ty f_e) fields in - Expr.estruct name fields' mark - | A.EDStructAccess { e = e_struct; name_opt; field } -> + Expr.estruct ~name ~fields mark + | A.EDStructAccess { e = e_struct; path = _; name_opt; field } -> let t_struct = match name_opt with | Some name -> TStruct name @@ -492,6 +513,7 @@ and typecheck_expr_top_down : "This is not a structure, cannot access field %s (%a)" field (format_typ ctx) (ty e_struct') in + let path, _ = A.StructName.Map.find name ctx.ctx_structs in let fld_ty = let str = try A.StructName.Map.find name env.structs @@ -526,7 +548,7 @@ and typecheck_expr_top_down : A.StructField.Map.find field str in let mark = mark_with_tau_and_unify fld_ty in - Expr.edstructaccess e_struct' field (Some name) mark + Expr.edstructaccess ~e:e_struct' ~path ~name_opt:(Some name) ~field mark | A.EStructAccess { e = e_struct; name; field } -> let fld_ty = let str = @@ -551,7 +573,7 @@ and typecheck_expr_top_down : typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TStruct name)) e_struct in - Expr.estructaccess e_struct' field name mark + Expr.estructaccess ~e:e_struct' ~field ~name mark | A.EInj { name; cons; e = e_enum } when Definitions.EnumName.equal name Expr.option_enum -> if Definitions.EnumConstructor.equal cons Expr.some_constr then @@ -560,7 +582,7 @@ and typecheck_expr_top_down : let e_enum' = typecheck_expr_top_down ~leave_unresolved ctx env cell_type e_enum in - Expr.einj e_enum' cons name mark + Expr.einj ~name ~cons ~e:e_enum' mark else (* None constructor *) let cell_type = unionfind (TAny (Any.fresh ())) in @@ -569,7 +591,7 @@ and typecheck_expr_top_down : typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TLit TUnit)) e_enum in - Expr.einj e_enum' cons name mark + Expr.einj ~name ~cons ~e:e_enum' mark | A.EInj { name; cons; e = e_enum } -> let mark = mark_with_tau_and_unify (unionfind (TEnum name)) in let e_enum' = @@ -577,7 +599,7 @@ and typecheck_expr_top_down : (A.EnumConstructor.Map.find cons (A.EnumName.Map.find name env.enums)) e_enum in - Expr.einj e_enum' cons name mark + Expr.einj ~e:e_enum' ~cons ~name mark | A.EMatch { e = e1; name; cases } when Definitions.EnumName.equal name Expr.option_enum -> let cell_type = unionfind ~pos:e1 (TAny (Any.fresh ())) in @@ -591,7 +613,7 @@ and typecheck_expr_top_down : let t_ret = unionfind ~pos:e (TAny (Any.fresh ())) in let mark = mark_with_tau_and_unify t_ret in let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_arg e1 in - let cases' = + let cases = A.EnumConstructor.Map.merge (fun _ e e_ty -> match e, e_ty with @@ -603,17 +625,19 @@ and typecheck_expr_top_down : | _ -> assert false) cases cases_ty in - - Expr.ematch e1' name cases' mark + Expr.ematch ~e:e1' ~name ~cases mark | A.EMatch { e = e1; name; cases } -> - let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in + let _path, cases_ty = + A.EnumName.Map.find name ctx.A.ctx_enums + in let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in let mark = mark_with_tau_and_unify t_ret in let e1' = - typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TEnum name)) + typecheck_expr_top_down ~leave_unresolved ctx env + (unionfind (TEnum name)) e1 in - let cases' = + let cases = A.EnumConstructor.Map.mapi (fun c_name e -> let c_ty = A.EnumConstructor.Map.find c_name cases_ty in @@ -624,13 +648,17 @@ and typecheck_expr_top_down : typecheck_expr_top_down ~leave_unresolved ctx env e_ty e) cases in - Expr.ematch e1' name cases' mark - | A.EScopeCall { scope; args } -> + Expr.ematch ~e:e1' ~name ~cases mark + | A.EScopeCall { path; scope; args } -> let scope_out_struct = + let ctx = Program.module_ctx ctx path in (A.ScopeName.Map.find scope ctx.ctx_scopes).out_struct_name in let mark = mark_with_tau_and_unify (unionfind (TStruct scope_out_struct)) in - let vars = A.ScopeName.Map.find scope env.scopes in + let vars = + let env = Env.module_env path env in + A.ScopeName.Map.find scope env.scopes + in let args' = A.ScopeVar.Map.mapi (fun name -> @@ -638,7 +666,7 @@ and typecheck_expr_top_down : (ast_to_typ (A.ScopeVar.Map.find name vars))) args in - Expr.escopecall scope args' mark + Expr.escopecall ~path ~scope ~args:args' mark | A.ERaise ex -> Expr.eraise ex context_mark | A.ECatch { body; exn; handler } -> let body' = typecheck_expr_top_down ~leave_unresolved ctx env tau body in @@ -655,16 +683,30 @@ and typecheck_expr_top_down : "Variable %s not found in the current context" (Bindlib.name_of v) in Expr.evar (Var.translate v) (mark_with_tau_and_unify tau') - | A.EExternal eref -> + | A.EExternal {path; name} -> + let ctx = Program.module_ctx ctx path in let ty = - try Qident.Map.find eref ctx.ctx_modules - with Not_found -> + let not_found pr x = Message.raise_spanned_error pos_e - "Could not resolve the reference to %a.@ Make sure the corresponding \ + "Could not resolve the reference to %a%a.@ Make sure the corresponding \ module was properly loaded?" - Qident.format eref + Print.path path + pr x + in + match Mark.remove name with + | A.External_value name -> + (try + ast_to_typ (A.TopdefName.Map.find name ctx.ctx_topdefs) + with Not_found -> not_found A.TopdefName.format name) + | A.External_scope name -> + (try + let scope_info = A.ScopeName.Map.find name ctx.ctx_scopes in + ast_to_typ (TArrow ([TStruct scope_info.in_struct_name, pos_e], + (TStruct scope_info.out_struct_name, pos_e)), + pos_e) + with Not_found -> not_found A.ScopeName.format name) in - Expr.eexternal eref (mark_with_tau_and_unify (ast_to_typ ty)) + Expr.eexternal ~path ~name (mark_with_tau_and_unify ty) | A.ELit lit -> Expr.elit lit (ty_mark (lit_type lit)) | A.ETuple es -> let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in @@ -990,7 +1032,8 @@ let program ~leave_unresolved prg = prg.decl_ctx with ctx_structs = A.StructName.Map.mapi - (fun s_name fields -> + (fun s_name (path, fields) -> + path, A.StructField.Map.mapi (fun f_name (t : A.typ) -> match Mark.remove t with @@ -1003,7 +1046,8 @@ let program ~leave_unresolved prg = prg.decl_ctx.ctx_structs; ctx_enums = A.EnumName.Map.mapi - (fun e_name cons -> + (fun e_name (path, cons) -> + path, A.EnumConstructor.Map.mapi (fun cons_name (t : A.typ) -> match Mark.remove t with diff --git a/compiler/shared_ast/typing.mli b/compiler/shared_ast/typing.mli index c295b34d..7de65aad 100644 --- a/compiler/shared_ast/typing.mli +++ b/compiler/shared_ast/typing.mli @@ -27,6 +27,8 @@ module Env : sig val add_toplevel_var : TopdefName.t -> typ -> 'e t -> 'e t val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t val add_scope : ScopeName.t -> vars:typ ScopeVar.Map.t -> 'e t -> 'e t + val add_module : ModuleName.t -> module_env:'e t -> 'e t -> 'e t + val module_env : path -> 'e t -> 'e t val open_scope : ScopeName.t -> 'e t -> 'e t end diff --git a/compiler/surface/ast.ml b/compiler/surface/ast.ml index ea4b27d9..e1f5b890 100644 --- a/compiler/surface/ast.ml +++ b/compiler/surface/ast.ml @@ -251,7 +251,7 @@ and scope_decl_context_io = { and scope_decl_context_scope = { scope_decl_context_scope_name : lident Mark.pos; - scope_decl_context_scope_sub_scope : uident Mark.pos; + scope_decl_context_scope_sub_scope : (path * uident Mark.pos) Mark.pos; scope_decl_context_scope_attribute : scope_decl_context_io; } @@ -309,11 +309,13 @@ and law_structure = | LawText of (string[@opaque]) | CodeBlock of code_block * source_repr * bool (* Metadata if true *) +and interface = code_block +(** Invariant: an interface shall only contain [*Decl] elements, or [Topdef] elements with [topdef_expr = None] *) + and program = { - program_interfaces : - ((Shared_ast.Qident.path[@opaque]) * code_item Mark.pos) list; program_items : law_structure list; program_source_files : (string[@opaque]) list; + program_modules : (uident * interface) list; } and source_file = law_structure list diff --git a/compiler/surface/parser.messages b/compiler/surface/parser.messages index f5a12ad9..03ef43c8 100644 --- a/compiler/surface/parser.messages +++ b/compiler/surface/parser.messages @@ -3957,45 +3957,11 @@ source_file: BEGIN_CODE DECLARATION SCOPE UIDENT COLON CONTEXT LIDENT CONDITION expected the next definition in scope -source_file: BEGIN_CODE DECLARATION SCOPE UIDENT COLON CONTEXT LIDENT CONDITION DEPENDS LIDENT CONTENT UIDENT DEFINED_AS -## -## Ends in an error in state: 344. -## -## scope_decl_item -> scope_decl_item_attribute lident CONDITION DEPENDS separated_nonempty_list(COMMA,var_content) . list(state) [ SCOPE OUTPUT LIDENT INTERNAL INPUT END_CODE DECLARATION CONTEXT ] -## -## The known suffix of the stack is as follows: -## scope_decl_item_attribute lident CONDITION DEPENDS separated_nonempty_list(COMMA,var_content) -## -## WARNING: This example involves spurious reductions. -## This implies that, although the LR(1) items shown above provide an -## accurate view of the past (what has been recognized so far), they -## may provide an INCOMPLETE view of the future (what was expected next). -## In state 21, spurious reduction of production quident -> UIDENT -## In state 30, spurious reduction of production primitive_typ -> quident -## In state 296, spurious reduction of production typ_data -> primitive_typ -## In state 307, spurious reduction of production separated_nonempty_list(COMMA,var_content) -> lident CONTENT typ_data -## - -expected the next definition in scope, or a comma followed by another argument declaration (', content ') - -source_file: BEGIN_CODE DECLARATION SCOPE UIDENT COLON LIDENT SCOPE UIDENT YEAR -## -## Ends in an error in state: 347. -## -## nonempty_list(addpos(scope_decl_item)) -> scope_decl_item . [ SCOPE END_CODE DECLARATION ] -## nonempty_list(addpos(scope_decl_item)) -> scope_decl_item . nonempty_list(addpos(scope_decl_item)) [ SCOPE END_CODE DECLARATION ] -## -## The known suffix of the stack is as follows: -## scope_decl_item -## - -expected the next declaration for the scope - source_file: BEGIN_CODE DECLARATION SCOPE UIDENT COLON LIDENT YEAR ## ## Ends in an error in state: 349. ## -## scope_decl_item -> lident . SCOPE UIDENT [ SCOPE OUTPUT LIDENT INTERNAL INPUT END_CODE DECLARATION CONTEXT ] +## scope_decl_item -> lident . SCOPE quident [ SCOPE OUTPUT LIDENT INTERNAL INPUT END_CODE DECLARATION CONTEXT ] ## ## The known suffix of the stack is as follows: ## lident @@ -4007,7 +3973,7 @@ source_file: BEGIN_CODE DECLARATION SCOPE UIDENT COLON LIDENT SCOPE YEAR ## ## Ends in an error in state: 350. ## -## scope_decl_item -> lident SCOPE . UIDENT [ SCOPE OUTPUT LIDENT INTERNAL INPUT END_CODE DECLARATION CONTEXT ] +## scope_decl_item -> lident SCOPE . quident [ SCOPE OUTPUT LIDENT INTERNAL INPUT END_CODE DECLARATION CONTEXT ] ## ## The known suffix of the stack is as follows: ## lident SCOPE diff --git a/compiler/surface/parser.mly b/compiler/surface/parser.mly index c5b66851..1c23f443 100644 --- a/compiler/surface/parser.mly +++ b/compiler/surface/parser.mly @@ -574,7 +574,7 @@ let scope_decl_item := scope_decl_context_item_states = states; } } -| i = lident ; SCOPE ; c = uident ; { +| i = lident ; SCOPE ; c = addpos(quident) ; { ContextScope{ scope_decl_context_scope_name = i; scope_decl_context_scope_sub_scope = c; diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index db8bb25a..055979e6 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -229,9 +229,9 @@ let rec parse_source_file (match input with Some input -> close_in input | None -> ()); let program = expand_includes source_file_name commands language in { - program_interfaces = []; program_items = program.Ast.program_items; program_source_files = source_file_name :: program.Ast.program_source_files; + program_modules = [] } (** Expands the include directives in a parsing result, thus parsing new source @@ -248,31 +248,33 @@ and expand_includes let sub_source = File.(source_dir / Mark.remove sub_source) in let includ_program = parse_source_file (FileName sub_source) language in { - program_interfaces = []; Ast.program_source_files = acc.Ast.program_source_files @ includ_program.program_source_files; Ast.program_items = acc.Ast.program_items @ includ_program.program_items; + Ast.program_modules = + acc.Ast.program_modules @ includ_program.program_modules; } | Ast.LawHeading (heading, commands') -> let { - Ast.program_interfaces = _; Ast.program_items = commands'; Ast.program_source_files = new_sources; + Ast.program_modules = new_modules; } = expand_includes source_file commands' language in { - Ast.program_interfaces = []; Ast.program_source_files = acc.Ast.program_source_files @ new_sources; Ast.program_items = acc.Ast.program_items @ [Ast.LawHeading (heading, commands')]; + Ast.program_modules = + acc.Ast.program_modules @ new_modules; } | i -> { acc with Ast.program_items = acc.Ast.program_items @ [i] }) { - Ast.program_interfaces = []; Ast.program_source_files = []; Ast.program_items = []; + Ast.program_modules = []; } commands @@ -297,30 +299,17 @@ let get_interface program = in List.fold_left filter [] program.Ast.program_items -let qualify_interface path code_items = - List.map (fun item -> path, item) code_items - (** {1 API} *) -let add_interface source_file language path program = - let interface = - parse_source_file source_file language - |> get_interface - |> qualify_interface path - in - { - program with - Ast.program_interfaces = - List.append interface program.Ast.program_interfaces; - } +let load_interface source_file language = + parse_source_file source_file language + |> get_interface let parse_top_level_file (source_file : Cli.input_file) (language : Cli.backend_lang) : Ast.program = let program = parse_source_file source_file language in - let interface = get_interface program in { program with Ast.program_items = law_struct_list_to_tree program.Ast.program_items; - Ast.program_interfaces = qualify_interface [] interface; } diff --git a/compiler/surface/parser_driver.mli b/compiler/surface/parser_driver.mli index f608e7bd..0819446f 100644 --- a/compiler/surface/parser_driver.mli +++ b/compiler/surface/parser_driver.mli @@ -19,13 +19,11 @@ open Catala_utils -val add_interface : +val load_interface : Cli.input_file -> Cli.backend_lang -> - Shared_ast.Qident.path -> - Ast.program -> - Ast.program -(** Reads only declarations in metadata in the supplied input file, and add them - to the given program *) + Ast.interface +(** Reads only declarations in metadata in the supplied input file, and only keeps type information *) val parse_top_level_file : Cli.input_file -> Cli.backend_lang -> Ast.program +(** Parses a catala file (handling file includes) and returns a program. Modules in the program are returned empty, use [load_interface] to fill them. *) diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index 3a95e427..329128fd 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -162,7 +162,7 @@ let rec print_z3model_expr (ctx : context) (ty : typ) (e : Expr.expr) : string = match Mark.remove ty with | TLit ty -> print_lit ty | TStruct name -> - let s = StructName.Map.find name ctx.ctx_decl.ctx_structs in + let _path, s = StructName.Map.find name ctx.ctx_decl.ctx_structs in let get_fieldname (fn : StructField.t) : string = Mark.remove (StructField.get_info fn) in @@ -188,7 +188,7 @@ let rec print_z3model_expr (ctx : context) (ty : typ) (e : Expr.expr) : string = let fd = Expr.get_func_decl e in let fd_name = Symbol.to_string (FuncDecl.get_name fd) in - let enum_ctrs = EnumName.Map.find name ctx.ctx_decl.ctx_enums in + let _path, enum_ctrs = EnumName.Map.find name ctx.ctx_decl.ctx_enums in let case = List.find (fun (ctr, _) -> @@ -315,7 +315,7 @@ and find_or_create_enum (ctx : context) (enum : EnumName.t) : match EnumName.Map.find_opt enum ctx.ctx_z3datatypes with | Some e -> ctx, e | None -> - let ctrs = EnumName.Map.find enum ctx.ctx_decl.ctx_enums in + let _path, ctrs = EnumName.Map.find enum ctx.ctx_decl.ctx_enums in let ctx, z3_ctrs = EnumConstructor.Map.fold (fun ctr ty (ctx, ctrs) -> @@ -340,7 +340,7 @@ and find_or_create_struct (ctx : context) (s : StructName.t) : | Some s -> ctx, s | None -> let s_name = Mark.remove (StructName.get_info s) in - let fields = StructName.Map.find s ctx.ctx_decl.ctx_structs in + let _path, fields = StructName.Map.find s ctx.ctx_decl.ctx_structs in let z3_fieldnames = List.map (fun f -> @@ -666,10 +666,10 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = mk_struct. The accessors of this constructor correspond to the field accesses *) let accessors = List.hd (Datatype.get_accessors z3_struct) in + let _path, fields = StructName.Map.find name ctx.ctx_decl.ctx_structs in let idx_mappings = List.combine - (StructField.Map.keys - (StructName.Map.find name ctx.ctx_decl.ctx_structs)) + (StructField.Map.keys fields) accessors in let _, accessor = @@ -685,11 +685,11 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = let ctx, z3_enum = find_or_create_enum ctx name in let ctx, z3_arg = translate_expr ctx e in let ctrs = Datatype.get_constructors z3_enum in + let _path, cons_map = EnumName.Map.find name ctx.ctx_decl.ctx_enums in (* This should always succeed if the expression is well-typed in dcalc *) let idx_mappings = List.combine - (EnumConstructor.Map.keys - (EnumName.Map.find name ctx.ctx_decl.ctx_enums)) + (EnumConstructor.Map.keys cons_map) ctrs in let _, ctr = diff --git a/doc/syntax/syntax_en.tex b/doc/syntax/syntax_en.tex index 3297a0f7..80d0e805 100644 --- a/doc/syntax/syntax_en.tex +++ b/doc/syntax/syntax_en.tex @@ -10,8 +10,6 @@ %\usepackage{booktabs} \usepackage{upquote} % Uncurly the quotes \usepackage{etoolbox} % for backquote fix -\usepackage[scaled=0.9]{DejaVuSans} -\usepackage[scaled=0.9]{DejaVuSansMono} \usepackage{mdframed} % nice frames \usepackage[nobottomtitles]{titlesec} % better titles \usepackage{enumitem} diff --git a/doc/syntax/syntax_fr.tex b/doc/syntax/syntax_fr.tex index 886b6738..85df32c9 100644 --- a/doc/syntax/syntax_fr.tex +++ b/doc/syntax/syntax_fr.tex @@ -10,8 +10,6 @@ %\usepackage{booktabs} \usepackage{upquote} % Uncurly the quotes \usepackage{etoolbox} % for backquote fix -\usepackage[scaled=0.9]{DejaVuSans} -\usepackage[scaled=0.9]{DejaVuSansMono} \usepackage{mdframed} % nice frames \usepackage[nobottomtitles]{titlesec} % better titles \usepackage{enumitem} diff --git a/tests/test_func/good/closure_conversion.catala_en b/tests/test_func/good/closure_conversion.catala_en index 9b4cbb26..69a021d1 100644 --- a/tests/test_func/good/closure_conversion.catala_en +++ b/tests/test_func/good/closure_conversion.catala_en @@ -15,9 +15,9 @@ scope S: $ catala Lcalc --avoid_exceptions -O --closure_conversion type eoption = | ENone of unit | ESome of any -type S = { z: eoption integer; } - type S_in = { x_in: eoption bool; } + +type S = { z: eoption integer; } let topval closure_f : (closure_env, integer) → eoption integer = λ (env: closure_env) (y: integer) → diff --git a/tests/test_func/good/closure_return.catala_en b/tests/test_func/good/closure_return.catala_en index fea06f77..4e26c81c 100644 --- a/tests/test_func/good/closure_return.catala_en +++ b/tests/test_func/good/closure_return.catala_en @@ -13,11 +13,11 @@ scope S: $ catala Lcalc --avoid_exceptions -O --closure_conversion type eoption = | ENone of unit | ESome of any +type S_in = { x_in: eoption bool; } + type S = { f: eoption ((closure_env, integer) → eoption integer * closure_env); } - -type S_in = { x_in: eoption bool; } let topval closure_f : (closure_env, integer) → eoption integer = λ (env: closure_env) (y: integer) → diff --git a/tests/test_func/good/scope_call_func_struct_closure.catala_en b/tests/test_func/good/scope_call_func_struct_closure.catala_en index cefbbbf6..26c8fe89 100644 --- a/tests/test_func/good/scope_call_func_struct_closure.catala_en +++ b/tests/test_func/good/scope_call_func_struct_closure.catala_en @@ -49,23 +49,23 @@ type Result = { q: eoption integer; } +type SubFoo1_in = { x_in: eoption integer; } + type SubFoo1 = { x: eoption integer; y: eoption ((closure_env, integer) → eoption integer * closure_env); } +type SubFoo2_in = { x1_in: eoption integer; x2_in: eoption integer; } + type SubFoo2 = { x1: eoption integer; y: eoption ((closure_env, integer) → eoption integer * closure_env); } -type Foo = { z: eoption integer; } - -type SubFoo1_in = { x_in: eoption integer; } - -type SubFoo2_in = { x1_in: eoption integer; x2_in: eoption integer; } - type Foo_in = { b_in: eoption bool; } + +type Foo = { z: eoption integer; } let topval closure_y : (closure_env, integer) → eoption integer = λ (env: closure_env) (z: integer) →