diff --git a/compiler/catala_utils/uid.ml b/compiler/catala_utils/uid.ml index 81e35070..451b4db6 100644 --- a/compiler/catala_utils/uid.ml +++ b/compiler/catala_utils/uid.ml @@ -81,14 +81,14 @@ module Gen () = Make (MarkedString) () module Module = struct include String + let to_string m = m let format ppf m = Format.fprintf ppf "@{%s@}" m - let of_string m = m end (* TODO: should probably be turned into an uid once we implement module import - directives; that will incur an additional resolution work on all paths - though ([module Module = Gen ()]) *) + directives; that will incur an additional resolution work on all paths though + ([module Module = Gen ()]) *) module Path = struct type t = Module.t list @@ -96,8 +96,7 @@ module Path = struct let format ppf p = Format.pp_print_list ~pp_sep:(fun _ () -> ()) - (fun ppf m -> - Format.fprintf ppf "%a@{.@}" Module.format m) + (fun ppf m -> Format.fprintf ppf "%a@{.@}" Module.format m) ppf p let to_string p = String.concat "." p @@ -110,14 +109,15 @@ module QualifiedMarkedString = struct let to_string (p, i) = Format.asprintf "%a%a" Path.format p MarkedString.format i + let format fmt (p, i) = - Path.format fmt p; MarkedString.format fmt i - let equal (p1, i1) (p2, i2) = - Path.equal p1 p2 && MarkedString.equal i1 i2 + Path.format fmt p; + MarkedString.format fmt i + + let equal (p1, i1) (p2, i2) = Path.equal p1 p2 && MarkedString.equal i1 i2 + let compare (p1, i1) (p2, i2) = - match Path.compare p1 p2 with - | 0 -> MarkedString.compare i1 i2 - | n -> n + match Path.compare p1 p2 with 0 -> MarkedString.compare i1 i2 | n -> n end module Gen_qualified () = struct diff --git a/compiler/catala_utils/uid.mli b/compiler/catala_utils/uid.mli index 196bb96b..ef59e552 100644 --- a/compiler/catala_utils/uid.mli +++ b/compiler/catala_utils/uid.mli @@ -63,32 +63,33 @@ module Gen () : Id with type info = MarkedString.info (** {2 Handling of Uids with additional path information} *) -module Module: sig +module Module : sig type t = private string (* TODO: this will become an uid at some point *) - val to_string: t -> string - val format: Format.formatter -> t -> unit - val equal : t -> t-> bool + + val to_string : t -> string + val format : Format.formatter -> t -> unit + val equal : t -> t -> bool val compare : t -> t -> int - val of_string: string -> t + val of_string : string -> t module Set : Set.S with type elt = t module Map : Map.S with type key = t end -module Path: sig +module Path : sig type t = Module.t list - val to_string: t -> string - val format: Format.formatter -> t -> unit - val equal : t -> t-> bool + val to_string : t -> string + val format : Format.formatter -> t -> unit + val equal : t -> t -> bool val compare : t -> t -> int end (** Same as [Gen] but also registers path information *) module Gen_qualified () : sig include Id with type info = Path.t * MarkedString.info - val fresh : Path.t -> MarkedString.info -> t + val fresh : Path.t -> MarkedString.info -> t val path : t -> Path.t val get_info : t -> MarkedString.info end diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index b4ed37ff..661340c9 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -75,7 +75,8 @@ let pos_mark_mk (type a m) (e : (a, m) gexpr) : let module_scope_sig scope_sig_ctx scope = let ssctx = - List.fold_left (fun ssctx m -> ModuleName.Map.find m ssctx.scope_sigs_modules) + List.fold_left + (fun ssctx m -> ModuleName.Map.find m ssctx.scope_sigs_modules) scope_sig_ctx (ScopeName.path scope) in ScopeName.Map.find scope ssctx.scope_sigs @@ -222,8 +223,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : Message.raise_spanned_error (Expr.pos e) "The constructor %a of enum %a is missing from this pattern \ matching" - EnumConstructor.format constructor - EnumName.format name + EnumConstructor.format constructor EnumName.format name in let case_d = translate_expr ctx case_e in ( EnumConstructor.Map.add constructor case_d d_cases, @@ -298,9 +298,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : match sc_sig.scope_sig_scope_ref with | Local_scope_ref v -> Expr.evar v m | External_scope_ref name -> - Expr.eexternal - ~name:(Mark.map (fun s -> External_scope s) name) - m + Expr.eexternal ~name:(Mark.map (fun s -> External_scope s) name) m in tag_with_log_entry e BeginCall [ScopeName.get_info scope; Mark.add (Expr.pos e) "direct"] @@ -409,7 +407,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : ts_in (Expr.pos e) | _ -> original_field_expr) (StructName.Map.find sc_sig.scope_sig_output_struct - ctx.decl_ctx.ctx_structs)) + ctx.decl_ctx.ctx_structs)) (Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e)) in (* Here we have to go through an if statement that records a decision being @@ -494,7 +492,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : |> SubScopeName.Map.find (Mark.remove alias) |> retrieve_in_and_out_typ_or_any var | ELocation (ToplevelVar { name }) -> ( - let decl_ctx = Program.module_ctx ctx.decl_ctx (TopdefName.path (Mark.remove name)) in + let decl_ctx = + Program.module_ctx ctx.decl_ctx (TopdefName.path (Mark.remove name)) + in let typ = TopdefName.Map.find (Mark.remove name) decl_ctx.ctx_topdefs in match Mark.remove typ with | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout @@ -573,8 +573,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : if path = [] then let v, _ = TopdefName.Map.find (Mark.remove name) ctx.toplevel_vars in Expr.evar v m - else - Expr.eexternal ~name:(Mark.map (fun n -> External_value n) name) m + else Expr.eexternal ~name:(Mark.map (fun n -> External_value n) name) m | EOp { op = Add_dat_dur _; tys } -> Expr.eop (Add_dat_dur ctx.date_rounding) tys m | EOp { op; tys } -> Expr.eop (Operator.translate op) tys m @@ -1087,9 +1086,8 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = prgm.program_modules; let sctx : 'm scope_sigs_ctx = let process_scope_sig scope_name scope = - Message.emit_debug "process_scope_sig %a (%a)" - ScopeName.format scope_name ScopeName.format - scope.Scopelang.Ast.scope_decl_name; + Message.emit_debug "process_scope_sig %a (%a)" ScopeName.format scope_name + ScopeName.format scope.Scopelang.Ast.scope_decl_name; let scope_path = ScopeName.path scope_name in let scope_ref = if scope_path = [] then @@ -1106,8 +1104,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = with ScopeName.Map.Not_found _ -> Message.raise_spanned_error (Mark.get (ScopeName.get_info scope_name)) - "Could not find scope %a" ScopeName.format - scope_name + "Could not find scope %a" ScopeName.format scope_name in let scope_sig_in_fields = (* Output fields have already been generated and added to the program @@ -1154,8 +1151,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = process_scope_sig scope_name scope_decl) prg.Scopelang.Ast.program_scopes; scope_sigs_modules = - ModuleName.Map.map process_modules - prg.Scopelang.Ast.program_modules; + ModuleName.Map.map process_modules prg.Scopelang.Ast.program_modules; } in { @@ -1165,18 +1161,14 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = process_scope_sig scope_name scope_decl) prgm.Scopelang.Ast.program_scopes; scope_sigs_modules = - ModuleName.Map.map - process_modules - prgm.Scopelang.Ast.program_modules; + ModuleName.Map.map process_modules prgm.Scopelang.Ast.program_modules; } in let rec gather_module_in_structs acc sctx = (* Expose all added in_structs from submodules at toplevel *) ModuleName.Map.fold (fun _ scope_sigs acc -> - let acc = - gather_module_in_structs acc scope_sigs.scope_sigs_modules - in + let acc = gather_module_in_structs acc scope_sigs.scope_sigs_modules in ScopeName.Map.fold (fun _ scope_sig_ctx acc -> let fields = @@ -1190,8 +1182,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = acc) scope_sig_ctx.scope_sig_in_fields StructField.Map.empty in - StructName.Map.add scope_sig_ctx.scope_sig_input_struct - fields acc) + StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc) scope_sigs.scope_sigs acc) sctx acc in diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 22f8d558..4c12d8fc 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -373,8 +373,7 @@ let rec translate_expr match Ident.Map.find_opt x ctxt.topdefs with | Some v -> Expr.elocation - (ToplevelVar - { name = v, Mark.get (TopdefName.get_info v) }) + (ToplevelVar { name = v, Mark.get (TopdefName.get_info v) }) emark | None -> Name_resolution.raise_unknown_identifier @@ -1392,10 +1391,13 @@ let init_scope_defs in scope_def) | Name_resolution.SubScope (v0, subscope_uid) -> - let sub_scope_def = - Name_resolution.get_scope_context ctxt subscope_uid + let sub_scope_def = Name_resolution.get_scope_context ctxt subscope_uid in + let ctxt = + List.fold_left + (fun ctx m -> ModuleName.Map.find m ctx.Name_resolution.modules) + ctxt + (ScopeName.path subscope_uid) in - let ctxt = List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.Name_resolution.modules) ctxt (ScopeName.path subscope_uid) in Ident.Map.fold (fun _ v scope_def_map -> match v with @@ -1475,19 +1477,15 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : (fun _ prg acc -> StructName.Map.union (fun _ _ _ -> assert false) - acc - prg.Ast.program_ctx.ctx_structs) - submodules - ctxt.Name_resolution.structs; + acc prg.Ast.program_ctx.ctx_structs) + submodules ctxt.Name_resolution.structs; ctx_enums = ModuleName.Map.fold (fun _ prg acc -> EnumName.Map.union (fun _ _ _ -> assert false) - acc - prg.Ast.program_ctx.ctx_enums) - submodules - ctxt.Name_resolution.enums; + acc prg.Ast.program_ctx.ctx_enums) + submodules ctxt.Name_resolution.enums; ctx_scopes = Ident.Map.fold (fun _ def acc -> @@ -1540,9 +1538,7 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : let id = ModuleName.of_string id in let modul = ModuleName.Map.find id acc.Ast.program_modules in let modul = - process_code_block - (ModuleName.Map.find id ctxt.modules) - modul intf + process_code_block (ModuleName.Map.find id ctxt.modules) modul intf in { acc with diff --git a/compiler/desugared/linting.ml b/compiler/desugared/linting.ml index e45b717a..5258f124 100644 --- a/compiler/desugared/linting.ml +++ b/compiler/desugared/linting.ml @@ -108,8 +108,7 @@ let detect_unused_struct_fields (p : program) : unit = ~f:(fun struct_fields_used e -> let rec structs_fields_used_expr e struct_fields_used = match Mark.remove e with - | EDStructAccess - { name_opt = Some name; e = e_struct; field } -> + | EDStructAccess { name_opt = Some name; e = e_struct; field } -> let field = StructName.Map.find name (Ident.Map.find field p.program_ctx.ctx_struct_fields) diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index 43c882e0..bf71a951 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -113,9 +113,11 @@ let get_var_io (ctxt : context) (uid : ScopeVar.t) : Surface.Ast.scope_decl_context_io = (ScopeVar.Map.find uid ctxt.var_typs).var_sig_io -let get_scope_context (ctxt: context) (scope: ScopeName.t) : scope_context = - let rec remove_common_prefix curpath scpath = match curpath, scpath with - | m1 :: cp, m2 :: sp when ModuleName.equal m1 m2 -> remove_common_prefix cp sp +let get_scope_context (ctxt : context) (scope : ScopeName.t) : scope_context = + let rec remove_common_prefix curpath scpath = + match curpath, scpath with + | m1 :: cp, m2 :: sp when ModuleName.equal m1 m2 -> + remove_common_prefix cp sp | _ -> scpath in let path = remove_common_prefix ctxt.path (ScopeName.path scope) in @@ -776,8 +778,7 @@ let get_def_key ScopeVar.format x_uid else None ) | [y; x] -> - let (subscope_uid, subscope_real_uid) - : SubScopeName.t * ScopeName.t = + let (subscope_uid, subscope_real_uid) : SubScopeName.t * ScopeName.t = match Ident.Map.find_opt (Mark.remove y) scope_ctxt.var_idmap with | Some (SubScope (v, u)) -> v, u | Some _ -> @@ -788,9 +789,7 @@ let get_def_key Message.raise_spanned_error pos "No definition found for subscope %a" Print.lit_style (Mark.remove y) in - let x_uid = - get_var_uid subscope_real_uid ctxt x - in + let x_uid = get_var_uid subscope_real_uid ctxt x in Ast.ScopeDef.SubScopeVar (subscope_uid, x_uid, pos) | _ -> Message.raise_spanned_error pos diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index fa2c0206..86cc6c21 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -69,7 +69,7 @@ type typedef = type context = { path : ModuleName.t list; - (** The current path being processed. Used for generating the Uids. *) + (** The current path being processed. Used for generating the Uids. *) typedefs : typedef Ident.Map.t; (** Gathers the names of the scopes, structs and enums *) field_idmap : StructField.t StructName.Map.t Ident.Map.t; @@ -108,7 +108,8 @@ val is_var_cond : context -> ScopeVar.t -> bool val get_var_io : context -> ScopeVar.t -> Surface.Ast.scope_decl_context_io val get_scope_context : context -> ScopeName.t -> scope_context -(** Get the corresponding scope context from the context, looking up into nested submodules as necessary, following the path information in the scope name *) +(** Get the corresponding scope context from the context, looking up into nested + submodules as necessary, following the path information in the scope name *) val get_var_uid : ScopeName.t -> context -> Ident.t Mark.pos -> ScopeVar.t (** Get the variable uid inside the scope given in argument *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 15ac2186..ad6ee289 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -254,8 +254,8 @@ module Commands = struct "Variable @{\"%s\"@} not found inside scope @{\"%a\"@}" variable ScopeName.format scope_uid | Some - (Desugared.Name_resolution.SubScope - (subscope_var_name, subscope_name)) -> ( + (Desugared.Name_resolution.SubScope (subscope_var_name, subscope_name)) + -> ( match second_part with | None -> Message.raise_error @@ -265,8 +265,10 @@ module Commands = struct SubScopeName.format subscope_var_name ScopeName.format scope_uid | Some second_part -> ( match - let ctxt = Desugared.Name_resolution.module_ctx ctxt - (List.map (fun m -> ModuleName.to_string m, Pos.no_pos) + let ctxt = + Desugared.Name_resolution.module_ctx ctxt + (List.map + (fun m -> ModuleName.to_string m, Pos.no_pos) (ScopeName.path subscope_name)) in Ident.Map.find_opt second_part diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 694025c3..7af28c30 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -754,7 +754,7 @@ let translate_program (prgm : typed D.program) : untyped A.program = ctx_structs = prgm.decl_ctx.ctx_structs |> StructName.Map.mapi (fun _n str -> - StructField.Map.map trans_typ_keep str); + StructField.Map.map trans_typ_keep str); } in diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 07db3760..a7bc61aa 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -279,17 +279,17 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : the original module to ensure that bindlib performs the exact same renamings ; or finally we could normalise the names at generation time (either at toplevel or in a dedicated submodule ?) *) - let path = - match Mark.remove name with - | External_value name -> TopdefName.path name - | External_scope name -> ScopeName.path name - in - Uid.Path.format fmt path; + let path = match Mark.remove name with - | External_value name -> - format_var_str fmt (Mark.remove (TopdefName.get_info name)) - | External_scope name -> - format_var_str fmt (Mark.remove (ScopeName.get_info name))) + | External_value name -> TopdefName.path name + | External_scope name -> ScopeName.path name + in + Uid.Path.format fmt path; + match Mark.remove name with + | External_value name -> + format_var_str fmt (Mark.remove (TopdefName.get_info name)) + | External_scope name -> + format_var_str fmt (Mark.remove (ScopeName.get_info name))) | ETuple es -> Format.fprintf fmt "@[(%a)@]" (Format.pp_print_list diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index 62c42fbe..062ee90d 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -230,9 +230,8 @@ module To_jsoo = struct (StructField.Map.bindings struct_fields) fmt_conv_funs () in - let format_enum_decl - fmt - (enum_name, (enum_cons : typ EnumConstructor.Map.t)) = + let format_enum_decl fmt (enum_name, (enum_cons : typ EnumConstructor.Map.t)) + = let fmt_enum_name fmt _ = format_enum_name fmt enum_name in let fmt_module_enum_name fmt () = To_ocaml.format_to_module_name fmt (`Ename enum_name) diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index ace37150..aae3d97e 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -104,8 +104,7 @@ module To_json = struct (t :: acc) @ collect_required_type_defs_from_scope_input s | TEnum e -> List.fold_left collect (t :: acc) - (EnumConstructor.Map.values - (EnumName.Map.find e ctx.ctx_enums)) + (EnumConstructor.Map.values (EnumName.Map.find e ctx.ctx_enums)) | TArray t -> collect acc t | _ -> acc in diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index 90fba913..9e8bece1 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -43,8 +43,8 @@ let rec format_expr | EFunc v -> Format.fprintf fmt "%a" format_func_name v | EStruct (es, s) -> let fields = StructName.Map.find s decl_ctx.ctx_structs in - Format.fprintf fmt "@[%a@ %a%a%a@]" - StructName.format s Print.punctuation "{" + Format.fprintf fmt "@[%a@ %a%a%a@]" StructName.format s + Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, (struct_field, _)) -> @@ -150,8 +150,8 @@ let rec format_statement ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt ((case, _), (arm_block, payload_name)) -> Format.fprintf fmt "%a %a%a@ %a @[%a@ %a@]" Print.punctuation - "|" Print.enum_constructor case Print.punctuation - ":" format_var_name payload_name Print.punctuation "→" + "|" Print.enum_constructor case Print.punctuation ":" + format_var_name payload_name Print.punctuation "→" (format_block decl_ctx ~debug) arm_block)) (List.combine (EnumConstructor.Map.bindings cons) arms) diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index ae8cff2f..43430f7e 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -440,9 +440,9 @@ let rec format_statement ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt (case_block, payload_var, cons_name) -> Format.fprintf fmt "%a.code == %a_Code.%a:@\n%a = %a.value@\n%a" - format_var tmp_var format_enum_name e_name - format_enum_cons_name cons_name format_var payload_var format_var - tmp_var (format_block ctx) case_block)) + format_var tmp_var format_enum_name e_name format_enum_cons_name + cons_name format_var payload_var format_var tmp_var + (format_block ctx) case_block)) cases | SReturn e1 -> Format.fprintf fmt "@[return %a@]" (format_expression ctx) diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index 0895b8c1..747351bf 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -98,7 +98,8 @@ let rule_used_defs = function expr_used_defs e | Ast.Call (subscope, subindex, _) -> if ScopeName.path subscope = [] then - VMap.singleton (Scope subscope) (Mark.get (SubScopeName.get_info subindex)) + VMap.singleton (Scope subscope) + (Mark.get (SubScopeName.get_info subindex)) else VMap.empty let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index 4c4fda30..5122d902 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -60,7 +60,11 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = | ELocation (SubScopeVar { scope; alias; var }) -> (* When referring to a subscope variable in an expression, we are referring to the output, hence we take the last state. *) - let ctx = List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.modules) ctx (ScopeName.path scope) in + let ctx = + List.fold_left + (fun ctx m -> ModuleName.Map.find m ctx.modules) + ctx (ScopeName.path scope) + in let var = match ScopeVar.Map.find (Mark.remove var) ctx.scope_var_mapping with | WholeVar new_s_var -> Mark.copy var new_s_var diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index e8bf316e..6c6e3daa 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -23,9 +23,8 @@ let struc (fmt : Format.formatter) (name : StructName.t) (fields : typ StructField.Map.t) : unit = - Format.fprintf fmt "%a %a %a %a@\n@[ %a@]@\n%a" Print.keyword - "struct" StructName.format name Print.punctuation "=" - Print.punctuation "{" + Format.fprintf fmt "%a %a %a %a@\n@[ %a@]@\n%a" Print.keyword "struct" + StructName.format name Print.punctuation "=" Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (field_name, typ) -> diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 43f30884..ce0b84a2 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -23,7 +23,6 @@ open Catala_utils module Runtime = Runtime_ocaml.Runtime module ModuleName = Uid.Module - module ScopeName = Uid.Gen_qualified () module TopdefName = Uid.Gen_qualified () module StructName = Uid.Gen_qualified () diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 920f4fa7..51490cbb 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -109,10 +109,7 @@ let subst binder vars = Bindlib.msubst binder (Array.of_list (List.map Mark.remove vars)) let evar v mark = Mark.add mark (Bindlib.box_var v) - -let eexternal ~name mark = - Mark.add mark (Bindlib.box (EExternal { name })) - +let eexternal ~name mark = Mark.add mark (Bindlib.box (EExternal { name })) let etuple args = Box.appn args @@ fun args -> ETuple args let etupleaccess e index size = @@ -540,8 +537,7 @@ let compare_location SubScopeVar { alias = ysubindex, _; var = ysubvar, _; _ } ) -> let c = SubScopeName.compare xsubindex ysubindex in if c = 0 then ScopeVar.compare xsubvar ysubvar else c - | ( ToplevelVar { name = vx, _ }, - ToplevelVar { name = vy, _ } ) -> + | ToplevelVar { name = vx, _ }, ToplevelVar { name = vy, _ } -> TopdefName.compare vx vy | DesugaredScopeVar _, _ -> -1 | _, DesugaredScopeVar _ -> 1 @@ -614,9 +610,7 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2 | ( EDStructAccess { e = e1; field = f1; name_opt = s1 }, EDStructAccess { e = e2; field = f2; name_opt = s2 } ) -> - Option.equal StructName.equal s1 s2 - && Ident.equal f1 f2 - && equal e1 e2 + Option.equal StructName.equal s1 s2 && Ident.equal f1 f2 && equal e1 e2 | ( EStructAccess { e = e1; field = f1; name = s1 }, EStructAccess { e = e2; field = f2; name = s2 } ) -> StructName.equal s1 s2 && StructField.equal f1 f2 && equal e1 e2 @@ -630,8 +624,7 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = && EnumConstructor.Map.equal equal cases1 cases2 | ( EScopeCall { scope = s1; args = fields1 }, EScopeCall { scope = s2; args = fields2 } ) -> - ScopeName.equal s1 s2 - && ScopeVar.Map.equal equal fields1 fields2 + ScopeName.equal s1 s2 && ScopeVar.Map.equal equal fields1 fields2 | ( ECustom { obj = obj1; targs = targs1; tret = tret1 }, ECustom { obj = obj2; targs = targs2; tret = tret2 } ) -> Type.equal_list targs1 targs2 && Type.equal tret1 tret2 && obj1 == obj2 diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 1b53676e..b888ac98 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -549,7 +549,8 @@ let rec evaluate_expr : "free variable found at evaluation (should not happen if term was \ well-typed)" | EExternal { name } -> - let path = match Mark.remove name with + let path = + match Mark.remove name with | External_value td -> TopdefName.path td | External_scope s -> ScopeName.path s in @@ -565,8 +566,7 @@ let rec evaluate_expr : (TStruct scope_info.out_struct_name, pos) ), pos ) with TopdefName.Map.Not_found _ | ScopeName.Map.Not_found _ -> - Message.raise_spanned_error pos - "Reference to %a could not be resolved" + Message.raise_spanned_error pos "Reference to %a could not be resolved" Print.external_ref name in let runtime_path = diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index 9bc428a7..185ddad4 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -86,8 +86,7 @@ let location (type a) (fmt : Format.formatter) (l : a glocation) : unit = | SubScopeVar { alias = subindex; var = subvar; _ } -> Format.fprintf fmt "%a.%a" SubScopeName.format (Mark.remove subindex) ScopeVar.format (Mark.remove subvar) - | ToplevelVar { name } -> - TopdefName.format fmt (Mark.remove name) + | ToplevelVar { name } -> TopdefName.format fmt (Mark.remove name) let enum_constructor (fmt : Format.formatter) (c : EnumConstructor.t) : unit = Format.fprintf fmt "@{%a@}" EnumConstructor.format c @@ -131,11 +130,9 @@ let rec typ_gen | None -> StructName.format fmt s | Some ctx -> let fields = StructName.Map.find s ctx.ctx_structs in - if StructField.Map.is_empty fields then ( - StructName.format fmt s) + if StructField.Map.is_empty fields then StructName.format fmt s else - Format.fprintf fmt "@[%a %a@,%a@;<0 -2>%a@]" - StructName.format s + Format.fprintf fmt "@[%a %a@,%a@;<0 -2>%a@]" StructName.format s (pp_color_string (List.hd colors)) "{" (StructField.Map.format_bindings @@ -155,8 +152,7 @@ let rec typ_gen | None -> Format.fprintf fmt "@[%a@]" EnumName.format e | Some ctx -> let def = EnumName.Map.find e ctx.ctx_enums in - Format.fprintf fmt "@[%a%a%a%a@]" EnumName.format e - punctuation "[" + Format.fprintf fmt "@[%a%a%a%a@]" EnumName.format e punctuation "[" (EnumConstructor.Map.format_bindings ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " punctuation "|") (fun fmt pp_case mty -> @@ -517,8 +513,7 @@ module ExprGen (C : EXPR_PARAM) = struct else match Mark.remove e with | EVar v -> var fmt v - | EExternal { name } -> - external_ref fmt name + | EExternal { name } -> external_ref fmt name | ETuple es -> Format.fprintf fmt "@[%a%a%a@]" (pp_color_string (List.hd colors)) @@ -859,8 +854,8 @@ let enum fmt (pp_name : Format.formatter -> unit) (c : typ EnumConstructor.Map.t) = - Format.fprintf fmt "@[%a %t %a@ %a@]" keyword "type" pp_name - punctuation "=" + Format.fprintf fmt "@[%a %t %a@ %a@]" keyword "type" pp_name punctuation + "=" (EnumConstructor.Map.format_bindings ~pp_sep:(fun _ _ -> ()) (fun fmt pp_n ty -> diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index 728c73d7..b0c3ec6f 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -39,8 +39,7 @@ let empty_ctx = } let module_ctx ctx path = - List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.ctx_modules) - ctx path + List.fold_left (fun ctx m -> ModuleName.Map.find m ctx.ctx_modules) ctx path let get_scope_body { code_items; _ } scope = match diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index d1770225..85b1fe2b 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -129,10 +129,8 @@ let rec format_typ ts (pp_color_string (List.hd colors)) ")" - | TStruct s -> - A.StructName.format fmt s - | TEnum e -> - A.EnumName.format fmt e + | TStruct s -> A.StructName.format fmt s + | TEnum e -> A.EnumName.format fmt e | TOption t -> Format.fprintf fmt "@[option %a@]" (format_typ_with_parens ~colors:(List.tl colors)) @@ -344,7 +342,7 @@ module Env = struct Option.bind (A.ScopeName.Map.find_opt scope t.scopes) (fun vmap -> A.ScopeVar.Map.find_opt var vmap) - let rec module_env path env = + let module_env path env = List.fold_left (fun env m -> A.ModuleName.Map.find m env.modules) env path let add v tau t = { t with vars = Var.Map.add v tau t.vars } @@ -673,7 +671,8 @@ and typecheck_expr_top_down : in Expr.evar (Var.translate v) (mark_with_tau_and_unify tau') | A.EExternal { name } -> - let path = match Mark.remove name with + let path = + match Mark.remove name with | External_value td -> A.TopdefName.path td | External_scope s -> A.ScopeName.path s in @@ -681,8 +680,8 @@ and typecheck_expr_top_down : let ty = let not_found pr x = Message.raise_spanned_error pos_e - "Could not resolve the reference to %a.@ Make sure the \ - corresponding module was properly loaded?" + "Could not resolve the reference to %a.@ Make sure the corresponding \ + module was properly loaded?" pr x in match Mark.remove name with @@ -1026,29 +1025,29 @@ let program ~leave_unresolved prg = prg.decl_ctx with ctx_structs = A.StructName.Map.mapi - (fun s_name (fields) -> - ( A.StructField.Map.mapi - (fun f_name (t : A.typ) -> - match Mark.remove t with - | TAny -> - typ_to_ast ~leave_unresolved - (A.StructField.Map.find f_name - (A.StructName.Map.find s_name new_env.structs)) - | _ -> t) - fields )) + (fun s_name fields -> + A.StructField.Map.mapi + (fun f_name (t : A.typ) -> + match Mark.remove t with + | TAny -> + typ_to_ast ~leave_unresolved + (A.StructField.Map.find f_name + (A.StructName.Map.find s_name new_env.structs)) + | _ -> t) + fields) prg.decl_ctx.ctx_structs; ctx_enums = A.EnumName.Map.mapi - (fun e_name (cons) -> - ( A.EnumConstructor.Map.mapi - (fun cons_name (t : A.typ) -> - match Mark.remove t with - | TAny -> - typ_to_ast ~leave_unresolved - (A.EnumConstructor.Map.find cons_name - (A.EnumName.Map.find e_name new_env.enums)) - | _ -> t) - cons )) + (fun e_name cons -> + A.EnumConstructor.Map.mapi + (fun cons_name (t : A.typ) -> + match Mark.remove t with + | TAny -> + typ_to_ast ~leave_unresolved + (A.EnumConstructor.Map.find cons_name + (A.EnumName.Map.find e_name new_env.enums)) + | _ -> t) + cons) prg.decl_ctx.ctx_enums; }; }