From 72882f82dfc75888470a9415a5b51a7ab38e140e Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 16 Aug 2023 00:04:45 +0200 Subject: [PATCH] Reformat --- build_system/clerk_driver.ml | 1 - build_system/clerk_runtest.ml | 12 +- build_system/clerk_runtest.mli | 18 +- compiler/catala_utils/map.ml | 11 +- compiler/dcalc/from_scopelang.ml | 401 ++++++++++--------- compiler/desugared/ast.ml | 3 +- compiler/desugared/dependency.ml | 3 +- compiler/desugared/disambiguate.ml | 30 +- compiler/desugared/from_surface.ml | 286 +++++++------ compiler/desugared/linting.ml | 3 +- compiler/desugared/name_resolution.ml | 76 ++-- compiler/desugared/name_resolution.mli | 6 +- compiler/driver.ml | 8 +- compiler/lcalc/ast.ml | 7 +- compiler/lcalc/compile_without_exceptions.ml | 9 +- compiler/lcalc/to_ocaml.ml | 30 +- compiler/plugins/api_web.ml | 93 ++--- compiler/plugins/json_schema.ml | 6 +- compiler/plugins/lazy_interp.ml | 19 +- compiler/plugins/modules.ml | 2 +- compiler/scalc/print.ml | 12 +- compiler/scalc/to_python.ml | 10 +- compiler/scopelang/ast.ml | 41 +- compiler/scopelang/ast.mli | 4 +- compiler/scopelang/dependency.ml | 14 +- compiler/scopelang/from_desugared.ml | 313 +++++++-------- compiler/scopelang/print.ml | 18 +- compiler/shared_ast/definitions.ml | 47 ++- compiler/shared_ast/expr.ml | 63 +-- compiler/shared_ast/expr.mli | 37 +- compiler/shared_ast/interpreter.ml | 48 +-- compiler/shared_ast/optimizations.ml | 15 +- compiler/shared_ast/print.ml | 44 +- compiler/shared_ast/program.ml | 7 +- compiler/shared_ast/program.mli | 3 +- compiler/shared_ast/scope.mli | 4 +- compiler/shared_ast/typing.ml | 113 +++--- compiler/shared_ast/var.ml | 1 + compiler/shared_ast/var.mli | 2 +- compiler/surface/ast.ml | 3 +- compiler/surface/parser_driver.ml | 8 +- compiler/surface/parser_driver.mli | 11 +- compiler/verification/z3backend.real.ml | 12 +- 43 files changed, 974 insertions(+), 880 deletions(-) diff --git a/build_system/clerk_driver.ml b/build_system/clerk_driver.ml index 903288e7..dc9d8b0b 100644 --- a/build_system/clerk_driver.ml +++ b/build_system/clerk_driver.ml @@ -173,7 +173,6 @@ let readdir_sort (dirname : string) : string array = dirs with Sys_error _ -> [||] - (** Given a file, looks in the relative [output] directory if there are files with the same base name that contain expected outputs for different *) let search_for_expected_outputs (file : string) : expected_output_descr list = diff --git a/build_system/clerk_runtest.ml b/build_system/clerk_runtest.ml index 103d4e0c..48db9619 100644 --- a/build_system/clerk_runtest.ml +++ b/build_system/clerk_runtest.ml @@ -40,7 +40,7 @@ let checkfile parents file = ~pp_sep:(fun ppf () -> Format.fprintf ppf " %a@ " String.format "→") Format.pp_print_string) (List.rev (file :: parents)); - (file :: parents), file + file :: parents, file let with_in_channel_safe parents file f = try File.with_in_channel file f @@ -186,9 +186,9 @@ let run_inline_tests | [] -> Message.emit_warning "No inline tests found in %s" file | file_tests -> Message.emit_debug "@[Running tests:@ %a@]" - (Format.pp_print_list - (fun ppf t -> Format.fprintf ppf "- @[%s:@ %d tests@]" - t.filename (List.length t.tests))) + (Format.pp_print_list (fun ppf t -> + Format.fprintf ppf "- @[%s:@ %d tests@]" t.filename + (List.length t.tests))) file_tests; let run test oc = List.iter @@ -214,7 +214,8 @@ let run_inline_tests let pid = let cwd = Unix.getcwd () in Unix.chdir file_dir; - Fun.protect ~finally:(fun () -> Unix.chdir cwd) @@ fun () -> + Fun.protect ~finally:(fun () -> Unix.chdir cwd) + @@ fun () -> Unix.create_process_env catala_exe cmd env Unix.stdin cmd_out_wr cmd_out_wr in @@ -256,4 +257,3 @@ let run_inline_tests Sys.rename out test.filename) else run test stdout) file_tests - diff --git a/build_system/clerk_runtest.mli b/build_system/clerk_runtest.mli index e1e4c3c6..cc9e4d17 100644 --- a/build_system/clerk_runtest.mli +++ b/build_system/clerk_runtest.mli @@ -14,10 +14,18 @@ License for the specific language governing permissions and limitations under the License. *) -(** This module contains specific commands used to detect and run inline tests in Catala files. The functionality is built into the `clerk runtest` subcommand, but is separate from the normal Clerk behaviour: Clerk drives Ninja, which in turn might need to evaluate tests as part of some rules and can run `clerk runtest` in a reentrant way. *) +(** This module contains specific commands used to detect and run inline tests + in Catala files. The functionality is built into the `clerk runtest` + subcommand, but is separate from the normal Clerk behaviour: Clerk drives + Ninja, which in turn might need to evaluate tests as part of some rules and + can run `clerk runtest` in a reentrant way. *) -val has_inline_tests: string -> bool -(** Checks if the given named file contains inline tests (either directly or through includes) *) +val has_inline_tests : string -> bool +(** Checks if the given named file contains inline tests (either directly or + through includes) *) -val run_inline_tests: reset:bool -> string -> string -> string list -> unit -(** [run_inline_tests ~reset file catala_exe catala_opts] runs the tests in Catala [file] using the given path to the Catala executable and the provided options. Output is printed to [stdout] if [reset] is false, otherwise [file] is replaced with the updated test results. *) +val run_inline_tests : reset:bool -> string -> string -> string list -> unit +(** [run_inline_tests ~reset file catala_exe catala_opts] runs the tests in + Catala [file] using the given path to the Catala executable and the provided + options. Output is printed to [stdout] if [reset] is false, otherwise [file] + is replaced with the updated test results. *) diff --git a/compiler/catala_utils/map.ml b/compiler/catala_utils/map.ml index 28c088eb..2d56a8c2 100644 --- a/compiler/catala_utils/map.ml +++ b/compiler/catala_utils/map.ml @@ -32,8 +32,7 @@ module type S = sig exception Not_found of key (* Slightly more informative [Not_found] exception *) - val find: key -> 'a t -> 'a - + val find : key -> 'a t -> 'a val keys : 'a t -> key list val values : 'a t -> 'a list val of_list : (key * 'a) list -> 'a t @@ -70,7 +69,6 @@ module type S = sig unit (** Formats all bindings of the map in order using the given separator (default ["; "]) and binding indicator (default [" = "]). *) - end module Make (Ord : OrderedType) : S with type key = Ord.t = struct @@ -79,14 +77,13 @@ module Make (Ord : OrderedType) : S with type key = Ord.t = struct exception Not_found of key let () = - Printexc.register_printer @@ function + Printexc.register_printer + @@ function | Not_found k -> Some (Format.asprintf "key '%a' not found in map" Ord.format k) | _ -> None - let find k t = - try find k t with Stdlib.Not_found -> raise (Not_found k) - + let find k t = try find k t with Stdlib.Not_found -> raise (Not_found k) let keys t = fold (fun k _ acc -> k :: acc) t [] |> List.rev let values t = fold (fun _ v acc -> v :: acc) t [] |> List.rev let of_list l = List.fold_left (fun m (k, v) -> add k v m) empty l diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 131af240..4e8e6f59 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -35,7 +35,8 @@ type 'm scope_ref = type 'm scope_sig_ctx = { scope_sig_local_vars : scope_var_ctx list; (** List of scope variables *) - scope_sig_scope_ref : 'm scope_ref; (** Var or external representing the scope *) + scope_sig_scope_ref : 'm scope_ref; + (** Var or external representing the scope *) scope_sig_input_struct : StructName.t; (** Scope input *) scope_sig_output_struct : StructName.t; (** Scope output *) scope_sig_in_fields : scope_input_var_ctx ScopeVar.Map.t; @@ -43,8 +44,8 @@ type 'm scope_sig_ctx = { } type 'm scope_sigs_ctx = { - scope_sigs: 'm scope_sig_ctx ScopeName.Map.t; - scope_sigs_modules: 'm scope_sigs_ctx ModuleName.Map.t; + scope_sigs : 'm scope_sig_ctx ScopeName.Map.t; + scope_sigs_modules : 'm scope_sigs_ctx ModuleName.Map.t; } type 'm ctx = { @@ -75,11 +76,12 @@ let pos_mark_mk (type a m) (e : (a, m) gexpr) : let rec module_scope_sig scope_sig_ctx path scope = match path with | [] -> ScopeName.Map.find scope scope_sig_ctx.scope_sigs - | (modname, mpos) :: path -> + | (modname, mpos) :: path -> ( match ModuleName.Map.find_opt modname scope_sig_ctx.scope_sigs_modules with | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname - | Some sig_ctx -> module_scope_sig sig_ctx path scope + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format + modname + | Some sig_ctx -> module_scope_sig sig_ctx path scope) let merge_defaults ~(is_func : bool) @@ -223,7 +225,8 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : Message.raise_spanned_error (Expr.pos e) "The constructor %a of enum %a%a is missing from this pattern \ matching" - EnumConstructor.format constructor Print.path path EnumName.format name + EnumConstructor.format constructor Print.path path + EnumName.format name in let case_d = translate_expr ctx case_e in ( EnumConstructor.Map.add constructor case_d d_cases, @@ -234,8 +237,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : if not (EnumConstructor.Map.is_empty remaining_e_cases) then Message.raise_spanned_error (Expr.pos e) "Pattern matching is incomplete for enum %a%a: missing cases %a" - Print.path path - EnumName.format name + Print.path path EnumName.format name (EnumConstructor.Map.format_keys ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")) remaining_e_cases; @@ -243,9 +245,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : Expr.ematch ~e:e1 ~name ~cases:d_cases m | EScopeCall { path; scope; args } -> let pos = Expr.mark_pos m in - let sc_sig = - module_scope_sig ctx.scopes_parameters path scope - in + let sc_sig = module_scope_sig ctx.scopes_parameters path scope in let in_var_map = ScopeVar.Map.merge (fun var_name (str_field : scope_input_var_ctx option) expr -> @@ -292,18 +292,20 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : in_var_map StructField.Map.empty in let arg_struct = - Expr.estruct ~name:sc_sig.scope_sig_input_struct ~fields:field_map (mark_tany m pos) + Expr.estruct ~name:sc_sig.scope_sig_input_struct ~fields:field_map + (mark_tany m pos) in let called_func = let m = mark_tany m pos in - let e = match sc_sig.scope_sig_scope_ref with + let e = + match sc_sig.scope_sig_scope_ref with | Local_scope_ref v -> Expr.evar v m | External_scope_ref (path, name) -> - Expr.eexternal ~path ~name:(Mark.map (fun s -> External_scope s) name) m + Expr.eexternal ~path + ~name:(Mark.map (fun s -> External_scope s) name) + m in - tag_with_log_entry - e - BeginCall + tag_with_log_entry e BeginCall [ScopeName.get_info scope; Mark.add (Expr.pos e) "direct"] in let single_arg = @@ -351,61 +353,67 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : (struct_output.struct_output_function_field x) ... } *) let result_eta_expanded = Expr.estruct ~name:sc_sig.scope_sig_output_struct - ~fields:(StructField.Map.mapi - (fun field typ -> - let original_field_expr = - Expr.estructaccess - ~e:(Expr.make_var result_var - (Expr.with_ty m - (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))) - ~field ~name:sc_sig.scope_sig_output_struct (Expr.with_ty m typ) - in - match Mark.remove typ with - | TArrow (ts_in, t_out) -> - (* Here the output scope struct field is a function so we - eta-expand it and insert logging instructions. Invariant: - works because there is no partial evaluation. *) - let params_vars = - ListLabels.mapi ts_in ~f:(fun i _ -> - Var.make ("param" ^ string_of_int i)) + ~fields: + (StructField.Map.mapi + (fun field typ -> + let original_field_expr = + Expr.estructaccess + ~e: + (Expr.make_var result_var + (Expr.with_ty m + (TStruct sc_sig.scope_sig_output_struct, Expr.pos e))) + ~field ~name:sc_sig.scope_sig_output_struct + (Expr.with_ty m typ) in - let f_markings = - [ScopeName.get_info scope; StructField.get_info field] - in - Expr.make_abs - (Array.of_list params_vars) - (tag_with_log_entry - (tag_with_log_entry - (Expr.eapp - (tag_with_log_entry original_field_expr BeginCall - f_markings) - (ListLabels.mapi (List.combine params_vars ts_in) - ~f:(fun i (param_var, t_in) -> - tag_with_log_entry - (Expr.make_var param_var (Expr.with_ty m t_in)) - (VarDef - { - log_typ = Mark.remove t_in; - log_io_output = false; - log_io_input = OnlyInput; - }) - (f_markings - @ [ - Mark.add (Expr.pos e) - ("input" ^ string_of_int i); - ]))) - (Expr.with_ty m t_out)) - (VarDef - { - log_typ = Mark.remove t_out; - log_io_output = true; - log_io_input = NoInput; - }) - (f_markings @ [Mark.add (Expr.pos e) "output"])) - EndCall f_markings) - ts_in (Expr.pos e) - | _ -> original_field_expr) - (snd (StructName.Map.find sc_sig.scope_sig_output_struct ctx.decl_ctx.ctx_structs))) + match Mark.remove typ with + | TArrow (ts_in, t_out) -> + (* Here the output scope struct field is a function so we + eta-expand it and insert logging instructions. Invariant: + works because there is no partial evaluation. *) + let params_vars = + ListLabels.mapi ts_in ~f:(fun i _ -> + Var.make ("param" ^ string_of_int i)) + in + let f_markings = + [ScopeName.get_info scope; StructField.get_info field] + in + Expr.make_abs + (Array.of_list params_vars) + (tag_with_log_entry + (tag_with_log_entry + (Expr.eapp + (tag_with_log_entry original_field_expr BeginCall + f_markings) + (ListLabels.mapi (List.combine params_vars ts_in) + ~f:(fun i (param_var, t_in) -> + tag_with_log_entry + (Expr.make_var param_var + (Expr.with_ty m t_in)) + (VarDef + { + log_typ = Mark.remove t_in; + log_io_output = false; + log_io_input = OnlyInput; + }) + (f_markings + @ [ + Mark.add (Expr.pos e) + ("input" ^ string_of_int i); + ]))) + (Expr.with_ty m t_out)) + (VarDef + { + log_typ = Mark.remove t_out; + log_io_output = true; + log_io_input = NoInput; + }) + (f_markings @ [Mark.add (Expr.pos e) "output"])) + EndCall f_markings) + ts_in (Expr.pos e) + | _ -> original_field_expr) + (snd + (StructName.Map.find sc_sig.scope_sig_output_struct + ctx.decl_ctx.ctx_structs))) (Expr.with_ty m (TStruct sc_sig.scope_sig_output_struct, Expr.pos e)) in (* Here we have to go through an if statement that records a decision being @@ -457,9 +465,9 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : match ctx.scope_name, Mark.remove f with | Some sname, ELocation loc -> ( match loc with - | ScopelangScopeVar { name = (v, _); _ } -> + | ScopelangScopeVar { name = v, _; _ } -> [ScopeName.get_info sname; ScopeVar.get_info v] - | SubScopeVar {scope; var = (v, _); _} -> + | SubScopeVar { scope; var = v, _; _ } -> [ScopeName.get_info scope; ScopeVar.get_info v] | ToplevelVar _ -> []) | _ -> [] @@ -483,20 +491,20 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : | _ -> ListLabels.map new_args ~f:(fun _ -> TAny), TAny in match Mark.remove f with - | ELocation (ScopelangScopeVar {name = var}) -> + | ELocation (ScopelangScopeVar { name = var }) -> retrieve_in_and_out_typ_or_any var ctx.scope_vars - | ELocation (SubScopeVar { alias; var ; _}) -> + | ELocation (SubScopeVar { alias; var; _ }) -> ctx.subscope_vars |> SubScopeName.Map.find (Mark.remove alias) |> retrieve_in_and_out_typ_or_any var | ELocation (ToplevelVar { path; name }) -> ( - let decl_ctx = Program.module_ctx ctx.decl_ctx path in - let typ = TopdefName.Map.find (Mark.remove name) decl_ctx.ctx_topdefs in - match Mark.remove typ with - | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout - | _ -> - Message.raise_spanned_error (Expr.pos e) - "Application of non-function toplevel variable") + let decl_ctx = Program.module_ctx ctx.decl_ctx path in + let typ = TopdefName.Map.find (Mark.remove name) decl_ctx.ctx_topdefs in + match Mark.remove typ with + | TArrow (tin, (tout, _)) -> List.map Mark.remove tin, tout + | _ -> + Message.raise_spanned_error (Expr.pos e) + "Application of non-function toplevel variable") | _ -> ListLabels.map new_args ~f:(fun _ -> TAny), TAny in @@ -567,14 +575,14 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : | ELocation (ToplevelVar { path = []; name }) -> let v, _ = TopdefName.Map.find (Mark.remove name) ctx.toplevel_vars in Expr.evar v m - | ELocation (ToplevelVar { path = _::_ as path; name }) -> + | ELocation (ToplevelVar { path = _ :: _ as path; name }) -> Expr.eexternal ~path ~name:(Mark.map (fun n -> External_value n) name) m | EOp { op = Add_dat_dur _; tys } -> Expr.eop (Add_dat_dur ctx.date_rounding) tys m | EOp { op; tys } -> Expr.eop (Operator.translate op) tys m - | ( EVar _ | EAbs _ | ELit _ | EStruct _ | EStructAccess _ - | ETuple _ | ETupleAccess _ | EInj _ | EEmptyError | EErrorOnEmpty _ - | EArray _ | EIfThenElse _ ) as e -> + | ( EVar _ | EAbs _ | ELit _ | EStruct _ | EStructAccess _ | ETuple _ + | ETupleAccess _ | EInj _ | EEmptyError | EErrorOnEmpty _ | EArray _ + | EIfThenElse _ ) as e -> Expr.map ~f:(translate_expr ctx) (e, m) (** The result of a rule translation is a list of assignment, with variables and @@ -781,9 +789,7 @@ let translate_rule all_subscope_output_vars in let subscope_func = - tag_with_log_entry - scope_dcalc_ref - BeginCall + tag_with_log_entry scope_dcalc_ref BeginCall [ sigma_name, pos_sigma; SubScopeName.get_info subindex; @@ -896,20 +902,21 @@ let translate_rules ((fun next -> next), ctx) rules in - let scope_sig_decl = - ScopeName.Map.find scope_name ctx.decl_ctx.ctx_scopes - in + let scope_sig_decl = ScopeName.Map.find scope_name ctx.decl_ctx.ctx_scopes in let return_exp = Expr.estruct ~name:scope_sig.scope_sig_output_struct - ~fields:(ScopeVar.Map.fold - (fun var (dcalc_var, _, io) acc -> - if Mark.remove io.Desugared.Ast.io_output then - let field = ScopeVar.Map.find var scope_sig_decl.out_struct_fields in - StructField.Map.add field - (Expr.make_var dcalc_var (mark_tany mark pos_sigma)) - acc - else acc) - new_ctx.scope_vars StructField.Map.empty) + ~fields: + (ScopeVar.Map.fold + (fun var (dcalc_var, _, io) acc -> + if Mark.remove io.Desugared.Ast.io_output then + let field = + ScopeVar.Map.find var scope_sig_decl.out_struct_fields + in + StructField.Map.add field + (Expr.make_var dcalc_var (mark_tany mark pos_sigma)) + acc + else acc) + new_ctx.scope_vars StructField.Map.empty) (mark_tany mark pos_sigma) in ( scope_lets @@ -918,7 +925,8 @@ let translate_rules (Expr.Box.lift return_exp)), new_ctx ) -(* From a scope declaration and definitions, create the corresponding scope body wrapped in the appropriate call convention. *) +(* From a scope declaration and definitions, create the corresponding scope body + wrapped in the appropriate call convention. *) let translate_scope_decl (ctx : 'm ctx) (scope_name : ScopeName.t) @@ -972,14 +980,16 @@ let translate_scope_decl (* Find a witness of a mark in the definitions *) match sigma.scope_decl_rules with | [] -> - (* Todo: are we sure this can't happen in normal code ? E.g. is calling a scope which only defines input variables already an error at this stage or not ? *) - Message.raise_spanned_error pos_sigma "Scope %a has no content" ScopeName.format scope_name - | (Definition (_,_,_,(_,m)) | Assertion (_,m) | Call (_,_,m)) :: _ -> + (* Todo: are we sure this can't happen in normal code ? E.g. is calling a + scope which only defines input variables already an error at this stage + or not ? *) + Message.raise_spanned_error pos_sigma "Scope %a has no content" + ScopeName.format scope_name + | (Definition (_, _, _, (_, m)) | Assertion (_, m) | Call (_, _, m)) :: _ -> m in let rules_with_return_expr, ctx = - translate_rules ctx scope_name sigma.scope_decl_rules sigma_info - scope_mark + translate_rules ctx scope_name sigma.scope_decl_rules sigma_info scope_mark scope_sig in let scope_variables = @@ -1034,18 +1044,17 @@ let translate_scope_decl }) (Bindlib.bind_var v next) (Expr.Box.lift - (Expr.make_var scope_input_var - (mark_tany scope_mark pos_sigma)))) + (Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma)))) scope_input_variables next in let scope_body = Bindlib.box_apply (fun scope_body_expr -> - { - scope_body_expr; - scope_body_input_struct = scope_input_struct_name; - scope_body_output_struct = scope_return_struct_name; - }) + { + scope_body_expr; + scope_body_input_struct = scope_input_struct_name; + scope_body_output_struct = scope_return_struct_name; + }) (Bindlib.bind_var scope_input_var (input_destructurings rules_with_return_expr)) in @@ -1062,8 +1071,7 @@ let translate_scope_decl let new_struct_ctx = StructName.Map.singleton scope_input_struct_name ([], field_map) in - ( scope_body, - new_struct_ctx ) + scope_body, new_struct_ctx let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = let defs_dependencies = Scopelang.Dependency.build_program_dep_graph prgm in @@ -1073,102 +1081,132 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = in let decl_ctx = prgm.program_ctx in Message.emit_debug "prog scopes: %a@ modules: %a" - (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) prgm.program_scopes - (ModuleName.Map.format - (fun fmt prg -> ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt prg.Scopelang.Ast.program_scopes)) prgm.program_modules; + (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) + prgm.program_scopes + (ModuleName.Map.format (fun fmt prg -> + ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt + prg.Scopelang.Ast.program_scopes)) + prgm.program_modules; let sctx : 'm scope_sigs_ctx = let process_scope_sig (scope_path, scope_name) scope = - Message.emit_debug "process_scope_sig %a%a (%a)" - Print.path scope_path ScopeName.format scope_name ScopeName.format scope.Scopelang.Ast.scope_decl_name; + Message.emit_debug "process_scope_sig %a%a (%a)" Print.path scope_path + ScopeName.format scope_name ScopeName.format + scope.Scopelang.Ast.scope_decl_name; let scope_ref = match scope_path with | [] -> let v = Var.make (Mark.remove (ScopeName.get_info scope_name)) in Local_scope_ref v | path -> - External_scope_ref (path, Mark.copy (ScopeName.get_info scope_name) scope_name) + External_scope_ref + (path, Mark.copy (ScopeName.get_info scope_name) scope_name) in let scope_info = try - ScopeName.Map.find scope_name (Program.module_ctx decl_ctx scope_path).ctx_scopes - with ScopeName.Map.Not_found _ -> Message.raise_spanned_error (Mark.get (ScopeName.get_info scope_name)) "Could not find scope %a%a" Print.path scope_path ScopeName.format scope_name + ScopeName.Map.find scope_name + (Program.module_ctx decl_ctx scope_path).ctx_scopes + with ScopeName.Map.Not_found _ -> + Message.raise_spanned_error + (Mark.get (ScopeName.get_info scope_name)) + "Could not find scope %a%a" Print.path scope_path ScopeName.format + scope_name in let scope_sig_in_fields = - (* Output fields have already been generated and added to the program ctx at this point, because they are visible to the user (manipulated as the return type of ScopeCalls) ; but input fields are used purely internally and need to be created here to implement the call convention for scopes. *) - ScopeVar.Map.filter_map - (fun dvar (typ, vis) -> - match Mark.remove vis.Desugared.Ast.io_input with - | NoInput -> None - | OnlyInput | Reentrant -> - let info = ScopeVar.get_info dvar in - let s = Mark.remove info ^ "_in" in - Some - { - scope_input_name = StructField.fresh (s, Mark.get info); - scope_input_io = vis.Desugared.Ast.io_input; - scope_input_typ = Mark.remove typ; - }) - scope.Scopelang.Ast.scope_sig - in - { - scope_sig_local_vars = - List.map - (fun (scope_var, (tau, vis)) -> + (* Output fields have already been generated and added to the program + ctx at this point, because they are visible to the user (manipulated + as the return type of ScopeCalls) ; but input fields are used purely + internally and need to be created here to implement the call + convention for scopes. *) + ScopeVar.Map.filter_map + (fun dvar (typ, vis) -> + match Mark.remove vis.Desugared.Ast.io_input with + | NoInput -> None + | OnlyInput | Reentrant -> + let info = ScopeVar.get_info dvar in + let s = Mark.remove info ^ "_in" in + Some { - scope_var_name = scope_var; - scope_var_typ = Mark.remove tau; - scope_var_io = vis; + scope_input_name = StructField.fresh (s, Mark.get info); + scope_input_io = vis.Desugared.Ast.io_input; + scope_input_typ = Mark.remove typ; }) - (ScopeVar.Map.bindings scope.scope_sig); - scope_sig_scope_ref = scope_ref; - scope_sig_input_struct = scope_info.in_struct_name; - scope_sig_output_struct = scope_info.out_struct_name; - scope_sig_in_fields; - } + scope.Scopelang.Ast.scope_sig + in + { + scope_sig_local_vars = + List.map + (fun (scope_var, (tau, vis)) -> + { + scope_var_name = scope_var; + scope_var_typ = Mark.remove tau; + scope_var_io = vis; + }) + (ScopeVar.Map.bindings scope.scope_sig); + scope_sig_scope_ref = scope_ref; + scope_sig_input_struct = scope_info.in_struct_name; + scope_sig_output_struct = scope_info.out_struct_name; + scope_sig_in_fields; + } in let rec process_modules path prg = - { scope_sigs = + { + scope_sigs = ScopeName.Map.mapi (fun scope_name (scope_decl, _) -> - process_scope_sig (path, scope_name) scope_decl) + process_scope_sig (path, scope_name) scope_decl) prg.Scopelang.Ast.program_scopes; scope_sigs_modules = - ModuleName.Map.mapi (fun modname prg -> + ModuleName.Map.mapi + (fun modname prg -> process_modules (path @ [modname, Pos.no_pos]) prg) prg.Scopelang.Ast.program_modules; } in - { scope_sigs = + { + scope_sigs = ScopeName.Map.mapi (fun scope_name (scope_decl, _) -> - process_scope_sig ([], scope_name) scope_decl) + process_scope_sig ([], scope_name) scope_decl) prgm.Scopelang.Ast.program_scopes; scope_sigs_modules = - ModuleName.Map.mapi (fun modname prg -> - process_modules [modname, Pos.no_pos] prg) + ModuleName.Map.mapi + (fun modname prg -> process_modules [modname, Pos.no_pos] prg) prgm.Scopelang.Ast.program_modules; } in let rec gather_module_in_structs acc path sctx = (* Expose all added in_structs from submodules at toplevel *) - ModuleName.Map.fold (fun modname scope_sigs acc -> + ModuleName.Map.fold + (fun modname scope_sigs acc -> let path = path @ [modname, Pos.no_pos] in - let acc = gather_module_in_structs acc path scope_sigs.scope_sigs_modules in - ScopeName.Map.fold (fun _ scope_sig_ctx acc -> + let acc = + gather_module_in_structs acc path scope_sigs.scope_sigs_modules + in + ScopeName.Map.fold + (fun _ scope_sig_ctx acc -> let fields = - ScopeVar.Map.fold (fun _ sivc acc -> - let pos = Mark.get (StructField.get_info sivc.scope_input_name) in - StructField.Map.add sivc.scope_input_name (sivc.scope_input_typ, pos) acc) + ScopeVar.Map.fold + (fun _ sivc acc -> + let pos = + Mark.get (StructField.get_info sivc.scope_input_name) + in + StructField.Map.add sivc.scope_input_name + (sivc.scope_input_typ, pos) + acc) scope_sig_ctx.scope_sig_in_fields StructField.Map.empty in StructName.Map.add scope_sig_ctx.scope_sig_input_struct (path, fields) acc) - scope_sigs.scope_sigs acc - ) - sctx - acc + scope_sigs.scope_sigs acc) + sctx acc + in + let decl_ctx = + { + decl_ctx with + ctx_structs = + gather_module_in_structs decl_ctx.ctx_structs [] sctx.scope_sigs_modules; + } in - let decl_ctx = { decl_ctx with ctx_structs = gather_module_in_structs decl_ctx.ctx_structs [] sctx.scope_sigs_modules } in let top_ctx = let toplevel_vars = TopdefName.Map.mapi @@ -1208,19 +1246,23 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = translate_scope_decl ctx scope_name (Mark.remove scope) in let scope_var = - match (ScopeName.Map.find scope_name sctx.scope_sigs).scope_sig_scope_ref with + match + (ScopeName.Map.find scope_name sctx.scope_sigs) + .scope_sig_scope_ref + with | Local_scope_ref v -> v | External_scope_ref _ -> assert false in ( { ctx with decl_ctx = - { ctx.decl_ctx with + { + ctx.decl_ctx with ctx_structs = StructName.Map.union (fun _ _ -> assert false) ctx.decl_ctx.ctx_structs scope_in_struct; - } + }; }, scope_var, Bindlib.box_apply @@ -1235,8 +1277,7 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program = ctx ) in let items, ctx = translate_defs top_ctx defs_ordering in - (* WIP TODO FIXME HERE: the scopes in submodules are not translated here it seems, and their input structs not added to decl_ctx (see From_surface:1476 for decl_ctx flattening info) *) - { - code_items = Bindlib.unbox items; - decl_ctx = ctx.decl_ctx; - } + (* WIP TODO FIXME HERE: the scopes in submodules are not translated here it + seems, and their input structs not added to decl_ctx (see From_surface:1476 + for decl_ctx flattening info) *) + { code_items = Bindlib.unbox items; decl_ctx = ctx.decl_ctx } diff --git a/compiler/desugared/ast.ml b/compiler/desugared/ast.ml index 734037b0..6569579f 100644 --- a/compiler/desugared/ast.ml +++ b/compiler/desugared/ast.ml @@ -248,7 +248,8 @@ let free_variables (def : rule RuleName.Map.t) : Pos.t ScopeDef.Map.t = (fun (loc, loc_pos) acc -> let usage = match loc with - | DesugaredScopeVar { name; state } -> Some (ScopeDef.Var (Mark.remove name, state)) + | DesugaredScopeVar { name; state } -> + Some (ScopeDef.Var (Mark.remove name, state)) | SubScopeVar { alias; var; _ } -> Some (ScopeDef.SubScopeVar diff --git a/compiler/desugared/dependency.ml b/compiler/desugared/dependency.ml index 6905c2af..6cc72798 100644 --- a/compiler/desugared/dependency.ml +++ b/compiler/desugared/dependency.ml @@ -261,7 +261,8 @@ let build_scope_dependencies (scope : Ast.scope) : ScopeDependencies.t = (fun used_var g -> let edge_from = match Mark.remove used_var with - | DesugaredScopeVar { name; state } -> Some (Vertex.Var (Mark.remove name, state)) + | DesugaredScopeVar { name; state } -> + Some (Vertex.Var (Mark.remove name, state)) | SubScopeVar { alias; _ } -> Some (Vertex.SubScope (Mark.remove alias)) | ToplevelVar _ -> None diff --git a/compiler/desugared/disambiguate.ml b/compiler/desugared/disambiguate.ml index ba4946e6..eacf6124 100644 --- a/compiler/desugared/disambiguate.ml +++ b/compiler/desugared/disambiguate.ml @@ -72,30 +72,30 @@ let program prg = let env = ScopeName.Map.fold (fun scope_name scope env -> - let vars = - ScopeDef.Map.fold - (fun var def vars -> - match var with - | Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars - | SubScopeVar _ -> vars) - scope.scope_defs ScopeVar.Map.empty - in - Typing.Env.add_scope scope_name ~vars env) + let vars = + ScopeDef.Map.fold + (fun var def vars -> + match var with + | Var (v, _states) -> ScopeVar.Map.add v def.scope_def_typ vars + | SubScopeVar _ -> vars) + scope.scope_defs ScopeVar.Map.empty + in + Typing.Env.add_scope scope_name ~vars env) prg.program_scopes env in env in let rec build_typing_env prg = - ModuleName.Map.fold (fun modname prg -> + ModuleName.Map.fold + (fun modname prg -> Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules - (base_typing_env prg) + prg.program_modules (base_typing_env prg) in let env = - ModuleName.Map.fold (fun modname prg -> + ModuleName.Map.fold + (fun modname prg -> Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules - (base_typing_env prg) + prg.program_modules (base_typing_env prg) in let program_topdefs = TopdefName.Map.map diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 8dfd9011..05b4761f 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -160,21 +160,24 @@ let rec disambiguate_constructor possible_c_uids; EnumName.Map.choose possible_c_uids | [enum] -> ( - (* The path is fully qualified *) - let e_uid = Name_resolution.get_enum ctxt enum in - try - let c_uid = EnumName.Map.find e_uid possible_c_uids in - e_uid, c_uid - with EnumName.Map.Not_found _ -> - Message.raise_spanned_error pos "Enum %s does not contain case %s" - (Mark.remove enum) (Mark.remove constructor)) - | (modname, mpos)::path -> + (* The path is fully qualified *) + let e_uid = Name_resolution.get_enum ctxt enum in + try + let c_uid = EnumName.Map.find e_uid possible_c_uids in + e_uid, c_uid + with EnumName.Map.Not_found _ -> + Message.raise_spanned_error pos "Enum %s does not contain case %s" + (Mark.remove enum) (Mark.remove constructor)) + | (modname, mpos) :: path -> ( match ModuleName.Map.find_opt modname ctxt.modules with | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format + modname | Some ctxt -> - let constructor = List.map (Mark.map (fun (_, c) -> path, c)) constructor0 in - disambiguate_constructor ctxt constructor pos + let constructor = + List.map (Mark.map (fun (_, c) -> path, c)) constructor0 + in + disambiguate_constructor ctxt constructor pos) let int100 = Runtime.integer_of_int 100 let rat100 = Runtime.decimal_of_integer int100 @@ -213,7 +216,7 @@ let rec translate_expr | None -> Ident.Map.empty | Some s -> (ScopeName.Map.find s ctxt.scopes).var_idmap in - let rec_helper ?(local_vars=local_vars) e = + let rec_helper ?(local_vars = local_vars) e = translate_expr scope inside_definition_of ctxt local_vars e in let pos = Mark.get expr in @@ -240,14 +243,14 @@ let rec translate_expr [tau] pos else let binding_var = Var.make (Mark.remove binding) in - let local_vars = Ident.Map.add (Mark.remove binding) binding_var local_vars in + let local_vars = + Ident.Map.add (Mark.remove binding) binding_var local_vars + in let e2 = rec_helper ~local_vars e2 in Expr.make_abs [| binding_var |] e2 [tau] pos) (EnumName.Map.find enum_uid ctxt.enums) in - Expr.ematch - ~e:(rec_helper e1_sub) - ~name:enum_uid ~cases emark + Expr.ematch ~e:(rec_helper e1_sub) ~name:enum_uid ~cases emark | Binop ((((S.And | S.Or | S.Xor), _) as op), e1, e2) -> check_formula op e1; check_formula op e2; @@ -349,7 +352,8 @@ let rec translate_expr with respect to the state that we are defining. *) let rec find_prev_state = function | [] -> None - | st0 :: st1 :: _ when StateName.equal inside_def_state st1 -> + | st0 :: st1 :: _ when StateName.equal inside_def_state st1 + -> Some st0 | _ :: states -> find_prev_state states in @@ -358,7 +362,9 @@ let rec translate_expr (* we take the last state in the chain *) Some (List.hd (List.rev states))) in - Expr.elocation (DesugaredScopeVar { name = uid, pos; state = x_state }) emark + Expr.elocation + (DesugaredScopeVar { name = uid, pos; state = x_state }) + emark | Some (SubScope _) (* Note: allowing access to a global variable with the same name as a subscope is disputable, but I see no good reason to forbid it either *) @@ -366,21 +372,21 @@ let rec translate_expr match Ident.Map.find_opt x ctxt.topdefs with | Some v -> Expr.elocation - (ToplevelVar { path = []; name = v, Mark.get (TopdefName.get_info v) }) + (ToplevelVar + { path = []; name = v, Mark.get (TopdefName.get_info v) }) emark | None -> Name_resolution.raise_unknown_identifier "for a local, scope-wide or global variable" (x, pos)))) - | Ident (path, name) -> + | Ident (path, name) -> ( let ctxt = Name_resolution.module_ctx ctxt path in - (match Ident.Map.find_opt (Mark.remove name) ctxt.topdefs with - | Some v -> - Expr.elocation - (ToplevelVar { path; name = v, Mark.get (TopdefName.get_info v) }) - emark - | None -> - Name_resolution.raise_unknown_identifier - "for an external variable" name) + match Ident.Map.find_opt (Mark.remove name) ctxt.topdefs with + | Some v -> + Expr.elocation + (ToplevelVar { path; name = v, Mark.get (TopdefName.get_info v) }) + emark + | None -> + Name_resolution.raise_unknown_identifier "for an external variable" name) | Dotted (e, ((path, x), _ppos)) -> ( match path, Mark.remove e with | [], Ident ([], (y, _)) @@ -397,29 +403,29 @@ let rec translate_expr Name_resolution.get_var_uid subscope_real_uid ctxt x in Expr.elocation - (SubScopeVar { - path = subscope_path; - scope = subscope_real_uid; - alias = (subscope_uid, pos); - var = (subscope_var_uid, pos) - }) + (SubScopeVar + { + path = subscope_path; + scope = subscope_real_uid; + alias = subscope_uid, pos; + var = subscope_var_uid, pos; + }) emark | _ -> (* In this case e.x is the struct field x access of expression e *) let e = rec_helper e in let rec get_str ctxt = function | [] -> None - | [c] -> - Some (Name_resolution.get_struct ctxt c) - | (modname, mpos) :: path -> + | [c] -> Some (Name_resolution.get_struct ctxt c) + | (modname, mpos) :: path -> ( match ModuleName.Map.find_opt modname ctxt.modules with | None -> - Message.raise_spanned_error mpos - "Module %a not found" ModuleName.format modname - | Some ctxt -> - get_str ctxt path + Message.raise_spanned_error mpos "Module %a not found" + ModuleName.format modname + | Some ctxt -> get_str ctxt path) in - Expr.edstructaccess ~e ~field:(Mark.remove x) ~name_opt:(get_str ctxt path) ~path emark) + Expr.edstructaccess ~e ~field:(Mark.remove x) + ~name_opt:(get_str ctxt path) ~path emark) | FunCall (f, args) -> Expr.eapp (rec_helper f) (List.map rec_helper args) emark | ScopeCall (((path, id), _), fields) -> @@ -467,11 +473,7 @@ let rec translate_expr let local_vars = Ident.Map.add (Mark.remove x) v local_vars in let tau = TAny, Mark.get x in (* This type will be resolved in Scopelang.Desambiguation *) - let fn = - Expr.make_abs [| v |] - (rec_helper ~local_vars e2) - [tau] pos - in + let fn = Expr.make_abs [| v |] (rec_helper ~local_vars e2) [tau] pos in Expr.eapp fn [rec_helper e1] emark | StructLit ((([], s_name), _), fields) -> let s_uid = @@ -540,38 +542,38 @@ let rec translate_expr let e_uid, c_uid = EnumName.Map.choose possible_c_uids in let payload = Option.map rec_helper payload in Expr.einj - ~e:(match payload with - | Some e' -> e' - | None -> Expr.elit LUnit mark_constructor) + ~e: + (match payload with + | Some e' -> e' + | None -> Expr.elit LUnit mark_constructor) ~cons:c_uid ~name:e_uid emark | path_enum -> ( - let path, enum = match List.rev path_enum with + let path, enum = + match List.rev path_enum with | enum :: rpath -> List.rev rpath, enum | _ -> assert false in - let ctxt = Name_resolution.module_ctx ctxt path in - let possible_c_uids = get_possible_c_uids ctxt in - (* The path has been qualified *) - let e_uid = Name_resolution.get_enum ctxt enum in - try - let c_uid = EnumName.Map.find e_uid possible_c_uids in - let payload = - Option.map rec_helper payload - in - Expr.einj - ~e:(match payload with + let ctxt = Name_resolution.module_ctx ctxt path in + let possible_c_uids = get_possible_c_uids ctxt in + (* The path has been qualified *) + let e_uid = Name_resolution.get_enum ctxt enum in + try + let c_uid = EnumName.Map.find e_uid possible_c_uids in + let payload = Option.map rec_helper payload in + Expr.einj + ~e: + (match payload with | Some e' -> e' | None -> Expr.elit LUnit mark_constructor) - ~cons:c_uid ~name:e_uid emark - with EnumName.Map.Not_found _ -> - Message.raise_spanned_error pos "Enum %s does not contain case %s" - (Mark.remove enum) constructor)) + ~cons:c_uid ~name:e_uid emark + with EnumName.Map.Not_found _ -> + Message.raise_spanned_error pos "Enum %s does not contain case %s" + (Mark.remove enum) constructor)) | MatchWith (e1, (cases, _cases_pos)) -> let e1 = rec_helper e1 in let cases_d, e_uid = disambiguate_match_and_build_expression scope inside_definition_of ctxt - local_vars - cases + local_vars cases in Expr.ematch ~e:e1 ~name:e_uid ~cases:cases_d emark | TestMatchCase (e1, pattern) -> @@ -594,9 +596,7 @@ let rec translate_expr [tau] pos) (EnumName.Map.find enum_uid ctxt.enums) in - Expr.ematch - ~e:(rec_helper e1) - ~name:enum_uid ~cases:cases emark + Expr.ematch ~e:(rec_helper e1) ~name:enum_uid ~cases emark | ArrayLit es -> Expr.earray (List.map rec_helper es) emark | CollectionOp (((S.Filter { f } | S.Map { f }) as op), collection) -> let collection = rec_helper collection in @@ -619,8 +619,8 @@ let rec translate_expr emark) [f_pred; collection] emark | CollectionOp - (S.AggregateArgExtremum { max; default; f = param_name, predicate }, collection) - -> + ( S.AggregateArgExtremum { max; default; f = param_name, predicate }, + collection ) -> let default = rec_helper default in let pos_dft = Expr.pos default in let collection = rec_helper collection in @@ -800,9 +800,7 @@ and disambiguate_match_and_build_expression let bind_match_cases (cases_d, e_uid, curr_index) (case, case_pos) = match case with | S.MatchCase case -> - let constructor, binding = - Mark.remove case.S.match_case_pattern - in + let constructor, binding = Mark.remove case.S.match_case_pattern in let e_uid', c_uid = disambiguate_constructor ctxt constructor (Mark.get case.S.match_case_pattern) @@ -826,7 +824,9 @@ and disambiguate_match_and_build_expression [None, Mark.get case.match_case_expr; None, Expr.pos e_case] "The constructor %a has been matched twice:" EnumConstructor.format c_uid); - let local_vars, param_var = create_var local_vars (Option.map Mark.remove binding) in + let local_vars, param_var = + create_var local_vars (Option.map Mark.remove binding) + in let case_body = translate_expr scope inside_definition_of ctxt local_vars case.S.match_case_expr @@ -882,7 +882,8 @@ and disambiguate_match_and_build_expression (* Creates the wildcard payload *) let local_vars, payload_var = create_var local_vars None in let case_body = - translate_expr scope inside_definition_of ctxt local_vars match_case_expr + translate_expr scope inside_definition_of ctxt local_vars + match_case_expr in let e_binder = Expr.bind [| payload_var |] case_body in @@ -972,8 +973,7 @@ let process_rule_parameters Message.raise_multispanned_error [ Some "Arguments declared here", pos; - ( Some "Definition missing the arguments", - Mark.get def.S.definition_name ); + Some "Definition missing the arguments", Mark.get def.S.definition_name; ] "This definition for %a is missing the arguments" Ast.ScopeDef.format decl_name @@ -982,9 +982,9 @@ let process_rule_parameters let local_vars, params = List.fold_left_map (fun local_vars ((lbl, pos), ty) -> - let v = Var.make lbl in - let local_vars = Ident.Map.add lbl v local_vars in - local_vars, ((v, pos), ty)) + let v = Var.make lbl in + let local_vars = Ident.Map.add lbl v local_vars in + local_vars, ((v, pos), ty)) Ident.Map.empty pdecl in local_vars, Some (params, pos_def) @@ -1005,7 +1005,8 @@ let process_default (cons : S.expression) : Ast.rule = let just = match just with - | Some just -> Some (translate_expr (Some scope) (Some def_key) ctxt local_vars just) + | Some just -> + Some (translate_expr (Some scope) (Some def_key) ctxt local_vars just) | None -> None in let just = merge_conditions precond just (Mark.get def_key) in @@ -1159,7 +1160,9 @@ let process_scope_use_item (ctxt : Name_resolution.context) (prgm : Ast.program) (item : S.scope_use_item Mark.pos) : Ast.program = - let precond = Option.map (translate_expr (Some scope) None ctxt Ident.Map.empty) precond in + let precond = + Option.map (translate_expr (Some scope) None ctxt Ident.Map.empty) precond + in match Mark.remove item with | S.Rule rule -> process_rule precond scope ctxt prgm rule | S.Definition def -> process_def precond scope ctxt prgm def @@ -1277,14 +1280,15 @@ let process_topdef let expr_opt = match def.S.topdef_expr, def.S.topdef_args with | None, _ -> None - | Some e, None -> Some (Expr.unbox_closed (translate_expr None None ctxt Ident.Map.empty e)) + | Some e, None -> + Some (Expr.unbox_closed (translate_expr None None ctxt Ident.Map.empty e)) | Some e, Some (args, _) -> let local_vars, args_tys = List.fold_left_map (fun local_vars ((lbl, pos), ty) -> - let v = Var.make lbl in - let local_vars = Ident.Map.add lbl v local_vars in - local_vars, ((v, pos), ty)) + let v = Var.make lbl in + let local_vars = Ident.Map.add lbl v local_vars in + local_vars, ((v, pos), ty)) Ident.Map.empty args in let body = translate_expr None None ctxt local_vars e in @@ -1417,9 +1421,8 @@ let init_scope_defs Ident.Map.fold add_def scope_idmap Ast.ScopeDef.Map.empty (** Main function of this module *) -let translate_program - (ctxt : Name_resolution.context) - (surface : S.program) : Ast.program = +let translate_program (ctxt : Name_resolution.context) (surface : S.program) : + Ast.program = let desugared = let get_program_scopes ctxt = ScopeName.Map.mapi @@ -1430,7 +1433,9 @@ let translate_program match v with | Name_resolution.SubScope _ -> acc | Name_resolution.ScopeVar v -> ( - let v_sig = ScopeVar.Map.find v ctxt.Name_resolution.var_typs in + let v_sig = + ScopeVar.Map.find v ctxt.Name_resolution.var_typs + in match v_sig.Name_resolution.var_sig_states_list with | [] -> ScopeVar.Map.add v Ast.WholeVar acc | states -> ScopeVar.Map.add v (Ast.States states) acc)) @@ -1458,39 +1463,52 @@ let translate_program in let rec make_ctx ctxt = let submodules = - ModuleName.Map.map make_ctx ctxt.Name_resolution.modules; + ModuleName.Map.map make_ctx ctxt.Name_resolution.modules in { Ast.program_ctx = { - (* After name resolution, type definitions (structs and enums) are exposed at toplevel for easier lookup, but their paths need to remain available for printing and later passes *) + (* After name resolution, type definitions (structs and enums) are + exposed at toplevel for easier lookup, but their paths need to + remain available for printing and later passes *) ctx_structs = - ModuleName.Map.fold (fun modname prg acc -> - StructName.Map.union (fun _ _ _ -> assert false) acc + ModuleName.Map.fold + (fun modname prg acc -> + StructName.Map.union + (fun _ _ _ -> assert false) + acc (StructName.Map.map (fun (path, def) -> (modname, Pos.no_pos) :: path, def) prg.Ast.program_ctx.ctx_structs)) submodules - (StructName.Map.map (fun def -> [], def) ctxt.Name_resolution.structs); + (StructName.Map.map + (fun def -> [], def) + ctxt.Name_resolution.structs); ctx_enums = - ModuleName.Map.fold (fun modname prg acc -> - EnumName.Map.union (fun _ _ _ -> assert false) acc + ModuleName.Map.fold + (fun modname prg acc -> + EnumName.Map.union + (fun _ _ _ -> assert false) + acc (EnumName.Map.map (fun (path, def) -> (modname, Pos.no_pos) :: path, def) prg.Ast.program_ctx.ctx_enums)) submodules - (EnumName.Map.map (fun def -> [], def) ctxt.Name_resolution.enums); + (EnumName.Map.map + (fun def -> [], def) + ctxt.Name_resolution.enums); ctx_scopes = Ident.Map.fold (fun _ def acc -> - match def with - | Name_resolution.TScope (scope, scope_info) -> - ScopeName.Map.add scope scope_info acc - | _ -> acc) + match def with + | Name_resolution.TScope (scope, scope_info) -> + ScopeName.Map.add scope scope_info acc + | _ -> acc) ctxt.Name_resolution.typedefs ScopeName.Map.empty; ctx_struct_fields = ctxt.Name_resolution.field_idmap; ctx_topdefs = ctxt.Name_resolution.topdef_types; - ctx_modules = ModuleName.Map.map (fun s -> s.Ast.program_ctx) submodules; + ctx_modules = + ModuleName.Map.map (fun s -> s.Ast.program_ctx) submodules; }; Ast.program_topdefs = TopdefName.Map.empty; Ast.program_scopes = get_program_scopes ctxt; @@ -1502,44 +1520,52 @@ let translate_program let process_code_block ctxt prgm block = List.fold_left (fun prgm item -> - match Mark.remove item with - | S.ScopeUse use -> process_scope_use ctxt prgm use - | S.Topdef def -> process_topdef ctxt prgm def - | S.ScopeDecl _ | S.StructDecl _ - | S.EnumDecl _ -> - prgm) + match Mark.remove item with + | S.ScopeUse use -> process_scope_use ctxt prgm use + | S.Topdef def -> process_topdef ctxt prgm def + | S.ScopeDecl _ | S.StructDecl _ | S.EnumDecl _ -> prgm) prgm block in - let rec process_structure - (prgm : Ast.program) - (item : S.law_structure) : Ast.program = + let rec process_structure (prgm : Ast.program) (item : S.law_structure) : + Ast.program = match item with | S.LawHeading (_, children) -> List.fold_left (fun prgm child -> process_structure prgm child) prgm children - | S.CodeBlock (block, _, _) -> - process_code_block ctxt prgm block + | S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block | S.LawInclude _ | S.LawText _ -> prgm in Message.emit_debug "DESUGARED → prog scopes: %a@ modules: %a" - (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) desugared.Ast.program_scopes - (ModuleName.Map.format - (fun fmt prg -> ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt prg.Ast.program_scopes)) desugared.Ast.program_modules; + (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) + desugared.Ast.program_scopes + (ModuleName.Map.format (fun fmt prg -> + ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt + prg.Ast.program_scopes)) + desugared.Ast.program_modules; let desugared = - List.fold_left (fun acc (id, intf) -> + List.fold_left + (fun acc (id, intf) -> let modul = ModuleName.Map.find id acc.Ast.program_modules in - let modul = process_code_block (Name_resolution.module_ctx ctxt [id, Pos.no_pos]) modul intf in - { acc with program_modules = - ModuleName.Map.add id modul acc.program_modules }) - desugared - surface.S.program_modules + let modul = + process_code_block + (Name_resolution.module_ctx ctxt [id, Pos.no_pos]) + modul intf + in + { + acc with + program_modules = ModuleName.Map.add id modul acc.program_modules; + }) + desugared surface.S.program_modules in let desugared = List.fold_left process_structure desugared surface.S.program_items in Message.emit_debug "DESUGARED2 → prog scopes: %a@ modules: %a" - (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) desugared.Ast.program_scopes - (ModuleName.Map.format - (fun fmt prg -> ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt prg.Ast.program_scopes)) desugared.Ast.program_modules; + (ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space) + desugared.Ast.program_scopes + (ModuleName.Map.format (fun fmt prg -> + ScopeName.Map.format_keys ~pp_sep:Format.pp_print_space fmt + prg.Ast.program_scopes)) + desugared.Ast.program_modules; desugared diff --git a/compiler/desugared/linting.ml b/compiler/desugared/linting.ml index 68e1d03f..a4d9b2af 100644 --- a/compiler/desugared/linting.ml +++ b/compiler/desugared/linting.ml @@ -108,7 +108,8 @@ let detect_unused_struct_fields (p : program) : unit = ~f:(fun struct_fields_used e -> let rec structs_fields_used_expr e struct_fields_used = match Mark.remove e with - | EDStructAccess { name_opt = Some name; e = e_struct; field; path = _ } -> + | EDStructAccess + { name_opt = Some name; e = e_struct; field; path = _ } -> let field = StructName.Map.find name (Ident.Map.find field p.program_ctx.ctx_struct_fields) diff --git a/compiler/desugared/name_resolution.ml b/compiler/desugared/name_resolution.ml index 549d5eea..bea4c5d4 100644 --- a/compiler/desugared/name_resolution.ml +++ b/compiler/desugared/name_resolution.ml @@ -65,8 +65,7 @@ type var_sig = { type typedef = | TStruct of StructName.t | TEnum of EnumName.t - | TScope of ScopeName.t * scope_info - (** Implicitly defined output struct *) + | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) type context = { typedefs : typedef Ident.Map.t; @@ -238,15 +237,15 @@ let get_scope ctxt id = Message.raise_spanned_error (Mark.get id) "No scope named %s found" (Mark.remove id) -let rec module_ctx ctxt path = match path with +let rec module_ctx ctxt path = + match path with | [] -> ctxt - | (modname, mpos) :: path -> - (match ModuleName.Map.find_opt modname ctxt.modules with - | None -> - Message.raise_spanned_error mpos - "Module %a not found" ModuleName.format modname - | Some ctxt -> - module_ctx ctxt path) + | (modname, mpos) :: path -> ( + match ModuleName.Map.find_opt modname ctxt.modules with + | None -> + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format + modname + | Some ctxt -> module_ctx ctxt path) (** {1 Declarations pass} *) @@ -267,8 +266,7 @@ let process_subscope_decl in Message.raise_multispanned_error [Some "first use", Mark.get info; Some "second use", s_pos] - "Subscope name @{\"%s\"@} already used" - (Mark.remove subscope) + "Subscope name @{\"%s\"@} already used" (Mark.remove subscope) | None -> let sub_scope_uid = SubScopeName.fresh (name, name_pos) in let original_subscope_uid = @@ -316,23 +314,24 @@ let rec process_base_typ | Surface.Ast.Text -> raise_unsupported_feature "text type" typ_pos | Surface.Ast.Named ([], (ident, _pos)) -> ( match Ident.Map.find_opt ident ctxt.typedefs with - | Some (TStruct s_uid) -> TStruct ( s_uid), typ_pos - | Some (TEnum e_uid) -> TEnum ( e_uid), typ_pos + | Some (TStruct s_uid) -> TStruct s_uid, typ_pos + | Some (TEnum e_uid) -> TEnum e_uid, typ_pos | Some (TScope (_, scope_str)) -> - TStruct ( scope_str.out_struct_name), typ_pos + TStruct scope_str.out_struct_name, typ_pos | None -> Message.raise_spanned_error typ_pos "Unknown type @{\"%s\"@}, not a struct or enum previously \ declared" ident) - | Surface.Ast.Named ((modul, mpos)::path, id) -> + | Surface.Ast.Named ((modul, mpos) :: path, id) -> ( match ModuleName.Map.find_opt modul ctxt.modules with | None -> Message.raise_spanned_error mpos - "This refers to module %a, which was not found" - ModuleName.format modul + "This refers to module %a, which was not found" ModuleName.format + modul | Some mod_ctxt -> - process_base_typ mod_ctxt Surface.Ast.(Data (Primitive (Named (path, id))), typ_pos)) + process_base_typ mod_ctxt + Surface.Ast.(Data (Primitive (Named (path, id))), typ_pos))) (** Process a type (function or not) *) let process_type (ctxt : context) ((naked_typ, typ_pos) : Surface.Ast.typ) : typ @@ -589,7 +588,9 @@ let process_scope_decl (ctxt : context) (decl : Surface.Ast.scope_decl) : (Mark.remove decl.scope_decl_name) (function | Some (TScope (scope, { in_struct_name; out_struct_name; _ })) -> - Some (TScope (scope, { in_struct_name; out_struct_name; out_struct_fields; })) + Some + (TScope + (scope, { in_struct_name; out_struct_name; out_struct_fields })) | _ -> assert false) ctxt.typedefs in @@ -681,9 +682,14 @@ let process_name_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : "toplevel definition") (Ident.Map.find_opt name ctxt.topdefs); let uid = TopdefName.fresh def.topdef_name in - { ctxt with + { + ctxt with topdefs = Ident.Map.add name uid ctxt.topdefs; - topdef_types = TopdefName.Map.add uid (process_type ctxt def.topdef_type) ctxt.topdef_types } + topdef_types = + TopdefName.Map.add uid + (process_type ctxt def.topdef_type) + ctxt.topdef_types; + } (** Process a code item that is a declaration *) let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : @@ -699,16 +705,14 @@ let process_decl_item (ctxt : context) (item : Surface.Ast.code_item Mark.pos) : let process_code_block (process_item : context -> Surface.Ast.code_item Mark.pos -> context) (ctxt : context) - (block : Surface.Ast.code_block) : - context = + (block : Surface.Ast.code_block) : context = List.fold_left (fun ctxt decl -> process_item ctxt decl) ctxt block (** Process a law structure, only considering the code blocks *) let rec process_law_structure (process_item : context -> Surface.Ast.code_item Mark.pos -> context) (ctxt : context) - (s : Surface.Ast.law_structure) : - context = + (s : Surface.Ast.law_structure) : context = match s with | Surface.Ast.LawHeading (_, children) -> List.fold_left @@ -758,7 +762,8 @@ let get_def_key ScopeVar.format x_uid else None ) | [y; x] -> - let (subscope_uid, (path, subscope_real_uid)) : SubScopeName.t * (path * ScopeName.t) = + let (subscope_uid, (path, subscope_real_uid)) + : SubScopeName.t * (path * ScopeName.t) = match Ident.Map.find_opt (Mark.remove y) scope_ctxt.var_idmap with | Some (SubScope (v, u)) -> v, u | Some _ -> @@ -933,14 +938,11 @@ let empty_ctxt = let import_module modules (name, intf) = let ctxt = { empty_ctxt with modules } in - let ctxt = - List.fold_left process_name_item ctxt intf - in - let ctxt = - List.fold_left process_decl_item ctxt intf - in + let ctxt = List.fold_left process_name_item ctxt intf in + let ctxt = List.fold_left process_decl_item ctxt intf in let ctxt = { ctxt with modules = empty_ctxt.modules } in - (* No submodules at the moment, a module may use the ones loaded before it, but doesn't reexport them *) + (* No submodules at the moment, a module may use the ones loaded before it, + but doesn't reexport them *) ModuleName.Map.add name ctxt modules (** Derive the context from metadata, in one pass over the declarations *) @@ -950,8 +952,10 @@ let form_context (prgm : Surface.Ast.program) : context = in let ctxt = { empty_ctxt with modules } in let rec gather_var_sigs acc modules = - (* Scope vars from imported modules need to be accessible directly for definitions through submodules *) - ModuleName.Map.fold (fun _modname mctx acc -> + (* Scope vars from imported modules need to be accessible directly for + definitions through submodules *) + ModuleName.Map.fold + (fun _modname mctx acc -> let acc = gather_var_sigs acc mctx.modules in ScopeVar.Map.union (fun _ _ -> assert false) acc mctx.var_typs) modules acc diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index f96b069c..2fbe495e 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -65,8 +65,7 @@ type var_sig = { type typedef = | TStruct of StructName.t | TEnum of EnumName.t - | TScope of ScopeName.t * scope_info - (** Implicitly defined output struct *) + | TScope of ScopeName.t * scope_info (** Implicitly defined output struct *) type context = { typedefs : typedef Ident.Map.t; @@ -152,7 +151,8 @@ val get_scope : context -> Ident.t Mark.pos -> ScopeName.t has a different kind *) val module_ctx : context -> path -> context -(** Returns the context corresponding to the given module path; raises a user error if the module is not found *) +(** Returns the context corresponding to the given module path; raises a user + error if the module is not found *) val process_type : context -> Surface.Ast.typ -> typ (** Convert a surface base type to an AST type *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 38ab88d8..7258a5cd 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -61,9 +61,7 @@ module Passes = struct let prg = Surface.Parser_driver.parse_top_level_file options.input_file language in - let prg = - Surface.Fill_positions.fill_pos_with_legislative_info prg - in + let prg = Surface.Fill_positions.fill_pos_with_legislative_info prg in let prg = { prg with program_modules = load_module_interfaces options link_modules } in @@ -256,8 +254,8 @@ module Commands = struct "Variable @{\"%s\"@} not found inside scope @{\"%a\"@}" variable ScopeName.format scope_uid | Some - (Desugared.Name_resolution.SubScope (subscope_var_name, (subscope_path, subscope_name))) - -> ( + (Desugared.Name_resolution.SubScope + (subscope_var_name, (subscope_path, subscope_name))) -> ( match second_part with | None -> Message.raise_error diff --git a/compiler/lcalc/ast.ml b/compiler/lcalc/ast.ml index a6c07258..d4457546 100644 --- a/compiler/lcalc/ast.ml +++ b/compiler/lcalc/ast.ml @@ -26,7 +26,8 @@ module OptionMonad = struct Expr.einj ~e ~cons:Expr.some_constr ~name:Expr.option_enum mark let empty ~(mark : 'a mark) = - Expr.einj ~e:(Expr.elit LUnit mark) ~cons:Expr.none_constr ~name:Expr.option_enum mark + Expr.einj ~e:(Expr.elit LUnit mark) ~cons:Expr.none_constr + ~name:Expr.option_enum mark let bind_var ~(mark : 'a mark) f x arg = let cases = @@ -36,8 +37,8 @@ module OptionMonad = struct let x = Var.make "_" in Expr.eabs (Expr.bind [| x |] - (Expr.einj ~e:(Expr.evar x mark) ~cons:Expr.none_constr ~name:Expr.option_enum - mark)) + (Expr.einj ~e:(Expr.evar x mark) ~cons:Expr.none_constr + ~name:Expr.option_enum mark)) [TLit TUnit, Expr.mark_pos mark] mark ); (* | None x -> None x *) diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 40f88e06..3d37ef29 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -169,7 +169,11 @@ let rec trans (ctx : typed ctx) (e : typed D.expr) : (lcalc, typed) boxed_gexpr Ast.OptionMonad.return ~mark (Expr.eapp (Expr.evar (trans_var ctx scope) mark) - [Expr.estruct ~name ~fields:(StructField.Map.map (trans ctx) fields) mark] + [ + Expr.estruct ~name + ~fields:(StructField.Map.map (trans ctx) fields) + mark; + ] mark) | EApp { f = (EVar ff, _) as f; args } when not (Var.Map.find ff ctx.ctx_vars).is_scope -> @@ -750,8 +754,7 @@ let translate_program (prgm : typed D.program) : untyped A.program = ctx_structs = prgm.decl_ctx.ctx_structs |> StructName.Map.mapi (fun _n (path, str) -> - path, - StructField.Map.map trans_typ_keep str); + path, StructField.Map.map trans_typ_keep str); } in diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 13366e40..19ead856 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -218,9 +218,7 @@ let rec format_typ (fmt : Format.formatter) (typ : typ) : unit = | TClosureEnv -> failwith "unimplemented!" let format_var_str (fmt : Format.formatter) (v : string) : unit = - let lowercase_name = - String.to_snake_case (String.to_ascii v) - in + let lowercase_name = String.to_snake_case (String.to_ascii v) in let lowercase_name = Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") @@ -276,13 +274,21 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : in match Mark.remove e with | EVar v -> Format.fprintf fmt "%a" format_var v - | EExternal { path; name } -> + | EExternal { path; name } -> ( Print.path fmt path; - (* FIXME: this is wrong in general !! - We assume the idents exposed by the module depend only on the original name, while they actually get through Bindlib and may have been renamed. A correct implem could use the runtime registration used by the interpreter, but that would be distasteful and incur a penalty ; or we would need to reproduce the same structure as in the original module to ensure that bindlib performs the exact same renamings ; or finally we could normalise the names at generation time (either at toplevel or in a dedicated submodule ?) *) - (match Mark.remove name with - | External_value name -> format_var_str fmt (Mark.remove (TopdefName.get_info name)) - | External_scope name -> format_var_str fmt (Mark.remove (ScopeName.get_info name))) + (* FIXME: this is wrong in general !! We assume the idents exposed by the + module depend only on the original name, while they actually get through + Bindlib and may have been renamed. A correct implem could use the runtime + registration used by the interpreter, but that would be distasteful and + incur a penalty ; or we would need to reproduce the same structure as in + the original module to ensure that bindlib performs the exact same + renamings ; or finally we could normalise the names at generation time + (either at toplevel or in a dedicated submodule ?) *) + match Mark.remove name with + | External_value name -> + format_var_str fmt (Mark.remove (TopdefName.get_info name)) + | External_scope name -> + format_var_str fmt (Mark.remove (ScopeName.get_info name))) | ETuple es -> Format.fprintf fmt "@[(%a)@]" (Format.pp_print_list @@ -550,12 +556,10 @@ let format_ctx match struct_or_enum with | Scopelang.Dependency.TVertex.Struct s -> let path, def = StructName.Map.find s ctx.ctx_structs in - if path = [] then - Format.fprintf fmt "%a@\n" format_struct_decl (s, def) + if path = [] then Format.fprintf fmt "%a@\n" format_struct_decl (s, def) | Scopelang.Dependency.TVertex.Enum e -> let path, def = EnumName.Map.find e ctx.ctx_enums in - if path = [] then - Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) + if path = [] then Format.fprintf fmt "%a@\n" format_enum_decl (e, def)) (type_ordering @ scope_structs) let rename_vars e = diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index 4cd232fe..5fc2589a 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -146,53 +146,53 @@ module To_jsoo = struct To_ocaml.format_to_module_name fmt (`Sname struct_name) in let fmt_to_jsoo fmt _ = - Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") - (fun fmt (struct_field, struct_field_type) -> - match Mark.remove struct_field_type with - | TArrow (t1, t2) -> - let args_names = - ListLabels.mapi t1 ~f:(fun i _ -> - "function_input" ^ string_of_int i) - in - Format.fprintf fmt - "@[method %a =@ Js.wrap_meth_callback@ @[(@,\ - fun _ %a ->@ %a (%a.%a %a))@]@]" - format_struct_field_name_camel_case struct_field - (Format.pp_print_list (fun fmt (arg_i, ti) -> - Format.fprintf fmt "(%s: %a)" arg_i format_typ ti)) - (List.combine args_names t1) - format_typ_to_jsoo t2 fmt_struct_name () - format_struct_field_name (None, struct_field) - (Format.pp_print_list (fun fmt (i, ti) -> - Format.fprintf fmt "@[(%a@ %a)@]" - format_typ_of_jsoo ti Format.pp_print_string i)) - (List.combine args_names t1) - | _ -> - Format.fprintf fmt "@[val %a =@ %a %a.%a@]" - format_struct_field_name_camel_case struct_field - format_typ_to_jsoo struct_field_type fmt_struct_name () - format_struct_field_name (None, struct_field)) - fmt + Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") + (fun fmt (struct_field, struct_field_type) -> + match Mark.remove struct_field_type with + | TArrow (t1, t2) -> + let args_names = + ListLabels.mapi t1 ~f:(fun i _ -> + "function_input" ^ string_of_int i) + in + Format.fprintf fmt + "@[method %a =@ Js.wrap_meth_callback@ @[(@,\ + fun _ %a ->@ %a (%a.%a %a))@]@]" + format_struct_field_name_camel_case struct_field + (Format.pp_print_list (fun fmt (arg_i, ti) -> + Format.fprintf fmt "(%s: %a)" arg_i format_typ ti)) + (List.combine args_names t1) + format_typ_to_jsoo t2 fmt_struct_name () + format_struct_field_name (None, struct_field) + (Format.pp_print_list (fun fmt (i, ti) -> + Format.fprintf fmt "@[(%a@ %a)@]" format_typ_of_jsoo + ti Format.pp_print_string i)) + (List.combine args_names t1) + | _ -> + Format.fprintf fmt "@[val %a =@ %a %a.%a@]" + format_struct_field_name_camel_case struct_field + format_typ_to_jsoo struct_field_type fmt_struct_name () + format_struct_field_name (None, struct_field)) + fmt (StructField.Map.bindings struct_fields) in let fmt_of_jsoo fmt _ = - Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") - (fun fmt (struct_field, struct_field_type) -> - match Mark.remove struct_field_type with - | TArrow _ -> - Format.fprintf fmt - "%a = failwith \"The function '%a' translation isn't yet \ - supported...\"" - format_struct_field_name (None, struct_field) - format_struct_field_name (None, struct_field) - | _ -> - Format.fprintf fmt - "@[%a =@ @[%a@ @[%a@,##.%a@]@]@]" - format_struct_field_name (None, struct_field) - format_typ_of_jsoo struct_field_type fmt_struct_name () - format_struct_field_name_camel_case struct_field) + Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@\n") + (fun fmt (struct_field, struct_field_type) -> + match Mark.remove struct_field_type with + | TArrow _ -> + Format.fprintf fmt + "%a = failwith \"The function '%a' translation isn't yet \ + supported...\"" + format_struct_field_name (None, struct_field) + format_struct_field_name (None, struct_field) + | _ -> + Format.fprintf fmt + "@[%a =@ @[%a@ @[%a@,##.%a@]@]@]" + format_struct_field_name (None, struct_field) format_typ_of_jsoo + struct_field_type fmt_struct_name () + format_struct_field_name_camel_case struct_field) fmt (StructField.Map.bindings struct_fields) in @@ -231,8 +231,9 @@ module To_jsoo = struct (StructField.Map.bindings struct_fields) fmt_conv_funs () in - let format_enum_decl fmt (enum_name, (path, (enum_cons : typ EnumConstructor.Map.t))) - = + let format_enum_decl + fmt + (enum_name, (path, (enum_cons : typ EnumConstructor.Map.t))) = let fmt_enum_name fmt _ = format_enum_name fmt enum_name in let fmt_module_enum_name fmt () = Print.path fmt path; diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index 8bb108b9..f853881a 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -80,8 +80,7 @@ module To_json = struct Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@\n") (fun fmt (field_name, field_type) -> - Format.fprintf fmt "@[\"%a%a\": {@\n%a@]@\n}" - Print.path path + Format.fprintf fmt "@[\"%a%a\": {@\n%a@]@\n}" Print.path path format_struct_field_name_camel_case field_name fmt_type field_type) fmt (StructField.Map.bindings fields) @@ -105,7 +104,8 @@ module To_json = struct (t :: acc) @ collect_required_type_defs_from_scope_input s | TEnum e -> List.fold_left collect (t :: acc) - (EnumConstructor.Map.values (snd (EnumName.Map.find e ctx.ctx_enums))) + (EnumConstructor.Map.values + (snd (EnumName.Map.find e ctx.ctx_enums))) | TArray t -> collect acc t | _ -> acc in diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index 958ff8ed..b5a0d3bb 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -234,15 +234,16 @@ let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : let m = Mark.get e in let application_arg = Expr.estruct ~name:scope_arg_struct - ~fields:(StructField.Map.map - (function - | TArrow (ty_in, ty_out), _ -> - Expr.make_abs - [| Var.make "_" |] - (Bindlib.box EEmptyError, Expr.with_ty m ty_out) - ty_in (Expr.mark_pos m) - | ty -> Expr.evar (Var.make "undefined_input") (Expr.with_ty m ty)) - (snd (StructName.Map.find scope_arg_struct ctx.ctx_structs))) + ~fields: + (StructField.Map.map + (function + | TArrow (ty_in, ty_out), _ -> + Expr.make_abs + [| Var.make "_" |] + (Bindlib.box EEmptyError, Expr.with_ty m ty_out) + ty_in (Expr.mark_pos m) + | ty -> Expr.evar (Var.make "undefined_input") (Expr.with_ty m ty)) + (snd (StructName.Map.find scope_arg_struct ctx.ctx_structs))) m in let e_app = Expr.eapp (Expr.box e) [application_arg] m in diff --git a/compiler/plugins/modules.ml b/compiler/plugins/modules.ml index 14fce14d..1a33f544 100644 --- a/compiler/plugins/modules.ml +++ b/compiler/plugins/modules.ml @@ -115,7 +115,7 @@ let compile options link_modules optimize check_invariants = gen_ocaml options link_modules optimize check_invariants (Some modname) None in let flags = ["-I"; Lazy.force runtime_dir] in - let shared_out = File.(Filename.dirname ml_file / basename ^ ".cmxs") in + let shared_out = File.((Filename.dirname ml_file / basename) ^ ".cmxs") in Message.emit_debug "Compiling OCaml shared object file @{%s@}..." shared_out; run_process "ocamlopt" ("-shared" :: ml_file :: "-o" :: shared_out :: flags); diff --git a/compiler/scalc/print.ml b/compiler/scalc/print.ml index dcbebc2c..e8af9a36 100644 --- a/compiler/scalc/print.ml +++ b/compiler/scalc/print.ml @@ -43,8 +43,8 @@ let rec format_expr | EFunc v -> Format.fprintf fmt "%a" format_func_name v | EStruct (es, s) -> let path, fields = StructName.Map.find s decl_ctx.ctx_structs in - Format.fprintf fmt "@[%a%a@ %a%a%a@]" Print.path path StructName.format s - Print.punctuation "{" + Format.fprintf fmt "@[%a%a@ %a%a%a@]" Print.path path + StructName.format s Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, (struct_field, _)) -> @@ -150,13 +150,11 @@ let rec format_statement ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt ((case, _), (arm_block, payload_name)) -> Format.fprintf fmt "%a %a%a%a@ %a @[%a@ %a@]" Print.punctuation - "|" Print.path path Print.enum_constructor case Print.punctuation ":" - format_var_name payload_name Print.punctuation "→" + "|" Print.path path Print.enum_constructor case Print.punctuation + ":" format_var_name payload_name Print.punctuation "→" (format_block decl_ctx ~debug) arm_block)) - (List.combine - (EnumConstructor.Map.bindings cons) - arms) + (List.combine (EnumConstructor.Map.bindings cons) arms) and format_block (decl_ctx : decl_ctx) diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 59765149..ebc91349 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -274,9 +274,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : | EVar v -> format_var fmt v | EFunc f -> format_func_name fmt f | EStruct (es, s) -> - let path, fields = - StructName.Map.find s ctx.ctx_structs - in + let path, fields = StructName.Map.find s ctx.ctx_structs in Format.fprintf fmt "%a%a(%a)" Print.path path format_struct_name s (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") @@ -442,9 +440,9 @@ let rec format_statement ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") (fun fmt (case_block, payload_var, cons_name) -> Format.fprintf fmt "%a.code == %a%a_Code.%a:@\n%a = %a.value@\n%a" - format_var tmp_var Print.path path format_enum_name e_name format_enum_cons_name - cons_name format_var payload_var format_var tmp_var - (format_block ctx) case_block)) + format_var tmp_var Print.path path format_enum_name e_name + format_enum_cons_name cons_name format_var payload_var format_var + tmp_var (format_block ctx) case_block)) cases | SReturn e1 -> Format.fprintf fmt "@[return %a@]" (format_expression ctx) diff --git a/compiler/scopelang/ast.ml b/compiler/scopelang/ast.ml index 163dbee4..d57c9661 100644 --- a/compiler/scopelang/ast.ml +++ b/compiler/scopelang/ast.ml @@ -75,29 +75,28 @@ let type_program (prg : 'm program) : typed program = let typing_env = TopdefName.Map.fold (fun name (_, ty) -> Typing.Env.add_toplevel_var name ty) - prg.program_topdefs - typing_env + prg.program_topdefs typing_env in let typing_env = ScopeName.Map.fold (fun scope_name scope_decl -> - let vars = ScopeVar.Map.map fst (Mark.remove scope_decl).scope_sig in - Typing.Env.add_scope scope_name ~vars) + let vars = ScopeVar.Map.map fst (Mark.remove scope_decl).scope_sig in + Typing.Env.add_scope scope_name ~vars) prg.program_scopes typing_env in typing_env in let rec build_typing_env prg = - ModuleName.Map.fold (fun modname prg -> + ModuleName.Map.fold + (fun modname prg -> Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules - (base_typing_env prg) + prg.program_modules (base_typing_env prg) in let typing_env = - ModuleName.Map.fold (fun modname prg -> + ModuleName.Map.fold + (fun modname prg -> Typing.Env.add_module modname ~module_env:(build_typing_env prg)) - prg.program_modules - (base_typing_env prg) + prg.program_modules (base_typing_env prg) in let program_topdefs = TopdefName.Map.map @@ -111,17 +110,17 @@ let type_program (prg : 'm program) : typed program = let program_scopes = ScopeName.Map.map (Mark.map (fun scope_decl -> - let typing_env = - ScopeVar.Map.fold - (fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env) - scope_decl.scope_sig typing_env - in - let scope_decl_rules = - List.map - (type_rule prg.program_ctx typing_env) - scope_decl.scope_decl_rules - in - {scope_decl with scope_decl_rules})) + let typing_env = + ScopeVar.Map.fold + (fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env) + scope_decl.scope_sig typing_env + in + let scope_decl_rules = + List.map + (type_rule prg.program_ctx typing_env) + scope_decl.scope_decl_rules + in + { scope_decl with scope_decl_rules })) prg.program_scopes in { prg with program_topdefs; program_scopes } diff --git a/compiler/scopelang/ast.mli b/compiler/scopelang/ast.mli index 0c23a772..73a8df7b 100644 --- a/compiler/scopelang/ast.mli +++ b/compiler/scopelang/ast.mli @@ -47,7 +47,9 @@ type 'm program = { program_scopes : 'm scope_decl Mark.pos ScopeName.Map.t; program_topdefs : ('m expr * typ) TopdefName.Map.t; program_modules : nil program ModuleName.Map.t; - (* Using [nil] here ensure that program interfaces don't contain any expressions. They won't contain any rules or topdefs, but will still have the scope signatures needed to respect the call convention *) + (* Using [nil] here ensure that program interfaces don't contain any + expressions. They won't contain any rules or topdefs, but will still have + the scope signatures needed to respect the call convention *) program_ctx : decl_ctx; } diff --git a/compiler/scopelang/dependency.ml b/compiler/scopelang/dependency.ml index f68921fa..70135a00 100644 --- a/compiler/scopelang/dependency.ml +++ b/compiler/scopelang/dependency.ml @@ -82,7 +82,8 @@ let rec expr_used_defs e = e VMap.empty in match e with - | ELocation (ToplevelVar { path = []; name = v, pos }), _ -> VMap.singleton (Topdef v) pos + | ELocation (ToplevelVar { path = []; name = v, pos }), _ -> + VMap.singleton (Topdef v) pos | (EScopeCall { path = []; scope; _ }, m) as e -> VMap.add (Scope scope) (Expr.mark_pos m) (recurse_subterms e) | EAbs { binder; _ }, _ -> @@ -95,7 +96,7 @@ let rule_used_defs = function (* TODO: maybe this info could be passed on from previous passes without walking through all exprs again *) expr_used_defs e - | Ast.Call ((_::_path, _), _, _) -> VMap.empty + | Ast.Call ((_ :: _path, _), _, _) -> VMap.empty | Ast.Call (([], subscope), subindex, _) -> VMap.singleton (Scope subscope) (Mark.get (SubScopeName.get_info subindex)) @@ -148,7 +149,8 @@ let build_program_dep_graph (prgm : 'm Ast.program) : SDependencies.t = used_defs g) g scope.Ast.scope_decl_rules) prgm.program_scopes g -(* TODO FIXME: Add submodules here, they may still need dependency resolution type-wise (?) *) +(* TODO FIXME: Add submodules here, they may still need dependency resolution + type-wise (?) *) let check_for_cycle_in_defs (g : SDependencies.t) : unit = (* if there is a cycle, there will be an strongly connected component of @@ -284,8 +286,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t Message.raise_spanned_error (Mark.get typ) "The type %a%a is defined using itself, which is forbidden \ since Catala does not provide recursive types" - Print.path path - TVertex.format used + Print.path path TVertex.format used else let edge = TDependencies.E.create used (Mark.get typ) def in TDependencies.add_edge_e g edge) @@ -307,8 +308,7 @@ let build_type_graph (structs : struct_ctx) (enums : enum_ctx) : TDependencies.t Message.raise_spanned_error (Mark.get typ) "The type %a%a is defined using itself, which is forbidden \ since Catala does not provide recursive types" - Print.path path - TVertex.format used + Print.path path TVertex.format used else let edge = TDependencies.E.create used (Mark.get typ) def in TDependencies.add_edge_e g edge) diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index a1e1a302..b45f4a88 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -35,11 +35,12 @@ type ctx = { let rec module_ctx ctx = function | [] -> ctx - | (modname, mpos) :: path -> + | (modname, mpos) :: path -> ( match ModuleName.Map.find_opt modname ctx.modules with | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname - | Some ctx -> module_ctx ctx path + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format + modname + | Some ctx -> module_ctx ctx path) let tag_with_log_entry (e : untyped Ast.expr boxed) @@ -51,8 +52,7 @@ let tag_with_log_entry [e] (Mark.get e) else e -let rec translate_expr (ctx : ctx) (e : D.expr) : - untyped Ast.expr boxed = +let rec translate_expr (ctx : ctx) (e : D.expr) : untyped Ast.expr boxed = let m = Mark.get e in match Mark.remove e with | EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m @@ -75,24 +75,30 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : | WholeVar new_s_var -> Mark.copy var new_s_var | States states -> Mark.copy var (snd (List.hd (List.rev states))) in - Expr.elocation (SubScopeVar { path; scope; alias; var}) m + Expr.elocation (SubScopeVar { path; scope; alias; var }) m | ELocation (DesugaredScopeVar { name; state = None }) -> Expr.elocation (ScopelangScopeVar - { name = - match ScopeVar.Map.find (Mark.remove name) ctx.scope_var_mapping - with - | WholeVar new_s_var -> Mark.copy name new_s_var - | States _ -> failwith "should not happen" } ) + { + name = + (match + ScopeVar.Map.find (Mark.remove name) ctx.scope_var_mapping + with + | WholeVar new_s_var -> Mark.copy name new_s_var + | States _ -> failwith "should not happen"); + }) m | ELocation (DesugaredScopeVar { name; state = Some state }) -> - Expr.elocation + Expr.elocation (ScopelangScopeVar - { name = - match ScopeVar.Map.find (Mark.remove name) ctx.scope_var_mapping - with - | WholeVar _ -> failwith "should not happen" - | States states -> Mark.copy name (List.assoc state states) }) + { + name = + (match + ScopeVar.Map.find (Mark.remove name) ctx.scope_var_mapping + with + | WholeVar _ -> failwith "should not happen" + | States states -> Mark.copy name (List.assoc state states)); + }) m | ELocation (ToplevelVar v) -> Expr.elocation (ToplevelVar v) m | EDStructAccess { name_opt = None; _ } -> @@ -117,19 +123,20 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : Expr.estructaccess ~e:e' ~field ~name m | EScopeCall { path; scope; args } -> Expr.escopecall ~path ~scope - ~args:(ScopeVar.Map.fold - (fun v e args' -> - let v' = - match ScopeVar.Map.find v ctx.scope_var_mapping with - | WholeVar v' -> v' - | States ((_, v') :: _) -> - (* When there are multiple states, the input is always the first - one *) - v' - | States [] -> assert false - in - ScopeVar.Map.add v' (translate_expr ctx e) args') - args ScopeVar.Map.empty) + ~args: + (ScopeVar.Map.fold + (fun v e args' -> + let v' = + match ScopeVar.Map.find v ctx.scope_var_mapping with + | WholeVar v' -> v' + | States ((_, v') :: _) -> + (* When there are multiple states, the input is always the + first one *) + v' + | States [] -> assert false + in + ScopeVar.Map.add v' (translate_expr ctx e) args') + args ScopeVar.Map.empty) m | EApp { f = EOp { op; tys }, m1; args } -> let args = List.map (translate_expr ctx) args in @@ -146,7 +153,7 @@ let rec translate_expr (ctx : ctx) (e : D.expr) : | EOp _ -> assert false (* Only allowed within [EApp] *) | ( EStruct _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | ELit _ | EApp _ | EDefault _ | EIfThenElse _ | EArray _ | EEmptyError - | EErrorOnEmpty _) as e -> + | EErrorOnEmpty _ ) as e -> Expr.map ~f:(translate_expr ctx) (e, m) (** {1 Rule tree construction} *) @@ -173,9 +180,7 @@ let def_to_exception_graph let rule_to_exception_graph (scope : D.scope) = function | Desugared.Dependency.Vertex.Var (var, state) -> ( let scope_def = - D.ScopeDef.Map.find - (D.ScopeDef.Var (var, state)) - scope.scope_defs + D.ScopeDef.Map.find (D.ScopeDef.Var (var, state)) scope.scope_defs in let var_def = scope_def.D.scope_def_rules in match Mark.remove scope_def.D.scope_def_io.io_input with @@ -195,9 +200,7 @@ let rule_to_exception_graph (scope : D.scope) = function | _ -> D.ScopeDef.Map.singleton (D.ScopeDef.Var (var, state)) - (def_to_exception_graph - (D.ScopeDef.Var (var, state)) - var_def)) + (def_to_exception_graph (D.ScopeDef.Var (var, state)) var_def)) | Desugared.Dependency.Vertex.SubScope sub_scope_index -> (* Before calling the sub_scope, we need to include all the re-definitions of subscope parameters*) @@ -211,9 +214,7 @@ let rule_to_exception_graph (scope : D.scope) = function (* We exclude subscope variables that have 0 re-definitions and are not visible in the input of the subscope *) && not - ((match - Mark.remove scope_def.D.scope_def_io.io_input - with + ((match Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> true | _ -> false) && RuleName.Map.is_empty scope_def.scope_def_rules)) @@ -230,9 +231,7 @@ let rule_to_exception_graph (scope : D.scope) = function (* This definition redefines a variable of the correct subscope. But we have to check that this redefinition is allowed with respect to the io parameters of that subscope variable. *) - (match - Mark.remove scope_def.D.scope_def_io.io_input - with + (match Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> Message.raise_multispanned_error (( Some "Incriminated subscope:", @@ -266,13 +265,11 @@ let rule_to_exception_graph (scope : D.scope) = function List.fold_left (fun exc_graphs (new_exc_graph, subscope_var, var_pos) -> D.ScopeDef.Map.add - (D.ScopeDef.SubScopeVar - (sub_scope_index, subscope_var, var_pos)) + (D.ScopeDef.SubScopeVar (sub_scope_index, subscope_var, var_pos)) new_exc_graph exc_graphs) D.ScopeDef.Map.empty (D.ScopeDef.Map.values sub_scope_vars_redefs) - | Assertion _ -> - D.ScopeDef.Map.empty (* no exceptions for assertions *) + | Assertion _ -> D.ScopeDef.Map.empty (* no exceptions for assertions *) let scope_to_exception_graphs (scope : D.scope) : Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t = @@ -351,9 +348,7 @@ let rec rule_tree_to_expr (* because each rule has its own variables parameters and we want to convert the whole rule tree into a function, we need to perform some alpha-renaming of all the expressions *) - let substitute_parameter - (e : D.expr boxed) - (rule : D.rule) : D.expr boxed = + let substitute_parameter (e : D.expr boxed) (rule : D.rule) : D.expr boxed = match params, rule.D.rule_parameter with | Some new_params, Some (old_params_with_types, _) -> let old_params, _ = List.split old_params_with_types in @@ -390,14 +385,10 @@ let rec rule_tree_to_expr ctx) in let base_just_list = - List.map - (fun rule -> substitute_parameter rule.D.rule_just rule) - base_rules + List.map (fun rule -> substitute_parameter rule.D.rule_just rule) base_rules in let base_cons_list = - List.map - (fun rule -> substitute_parameter rule.D.rule_cons rule) - base_rules + List.map (fun rule -> substitute_parameter rule.D.rule_cons rule) base_rules in let translate_and_unbox_list (list : D.expr boxed list) : untyped Ast.expr boxed list = @@ -473,24 +464,17 @@ let translate_def (* Here, we have to transform this list of rules into a default tree. *) let top_list = def_map_to_tree def exc_graph in let is_input = - match Mark.remove io.D.io_input with - | OnlyInput -> true - | _ -> false + match Mark.remove io.D.io_input with OnlyInput -> true | _ -> false in let is_reentrant = - match Mark.remove io.D.io_input with - | Reentrant -> true - | _ -> false + match Mark.remove io.D.io_input with Reentrant -> true | _ -> false in let top_value : D.rule option = if is_cond && ((not is_subscope_var) || (is_subscope_var && is_input)) then (* We add the bottom [false] value for conditions, only for the scope where the condition is declared. Except when the variable is an input, where we want the [false] to be added at each caller parent scope. *) - Some - (D.always_false_rule - (D.ScopeDef.get_position def_info) - params) + Some (D.always_false_rule (D.ScopeDef.get_position def_info) params) else None in if @@ -550,20 +534,16 @@ let translate_def exceptions to the default value *) Node (top_list, [top_value]) | [top_tree], None -> top_tree - | _, None -> - Node (top_list, [D.empty_rule (Mark.get typ) params])) + | _, None -> Node (top_list, [D.empty_rule (Mark.get typ) params])) let translate_rule ctx (scope : D.scope) (exc_graphs : - Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) - = function + Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) = function | Desugared.Dependency.Vertex.Var (var, state) -> ( let scope_def = - D.ScopeDef.Map.find - (D.ScopeDef.Var (var, state)) - scope.scope_defs + D.ScopeDef.Map.find (D.ScopeDef.Var (var, state)) scope.scope_defs in let var_def = scope_def.D.scope_def_rules in let var_params = scope_def.D.scope_def_parameters in @@ -613,9 +593,7 @@ let translate_rule (* We exclude subscope variables that have 0 re-definitions and are not visible in the input of the subscope *) && not - ((match - Mark.remove scope_def.D.scope_def_io.io_input - with + ((match Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> true | _ -> false) && RuleName.Map.is_empty scope_def.scope_def_rules)) @@ -633,9 +611,7 @@ let translate_rule (* This definition redefines a variable of the correct subscope. But we have to check that this redefinition is allowed with respect to the io parameters of that subscope variable. *) - (match - Mark.remove scope_def.D.scope_def_io.io_input - with + (match Mark.remove scope_def.D.scope_def_io.io_input with | NoInput -> assert false (* error already raised *) | OnlyInput when RuleName.Map.is_empty def && not is_cond -> assert false (* error already raised *) @@ -652,28 +628,28 @@ let translate_rule SubScopeName.Map.find sub_scope_index scope.scope_sub_scopes in Ast.Definition - ( ( SubScopeVar { - path = subscop_path; - scope = subscop_real_name; - alias = sub_scope_index, var_pos; - var = - match - ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping - with - | WholeVar v -> v, var_pos - | States states -> - (* When defining a sub-scope variable, we always define - its first state in the sub-scope. *) - snd (List.hd states), var_pos }, + ( ( SubScopeVar + { + path = subscop_path; + scope = subscop_real_name; + alias = sub_scope_index, var_pos; + var = + (match + ScopeVar.Map.find sub_scope_var ctx.scope_var_mapping + with + | WholeVar v -> v, var_pos + | States states -> + (* When defining a sub-scope variable, we always + define its first state in the sub-scope. *) + snd (List.hd states), var_pos); + }, var_pos ), def_typ, scope_def.D.scope_def_io, Expr.unbox expr_def )) sub_scope_vars_redefs_candidates in - let sub_scope_vars_redefs = - D.ScopeDef.Map.values sub_scope_vars_redefs - in + let sub_scope_vars_redefs = D.ScopeDef.Map.values sub_scope_vars_redefs in sub_scope_vars_redefs @ [ Ast.Call @@ -698,9 +674,7 @@ let translate_scope_interface ctx scope = match states with | WholeVar -> let scope_def = - D.ScopeDef.Map.find - (D.ScopeDef.Var (var, None)) - scope.D.scope_defs + D.ScopeDef.Map.find (D.ScopeDef.Var (var, None)) scope.D.scope_defs in let typ = scope_def.scope_def_typ in ScopeVar.Map.add @@ -731,19 +705,18 @@ let translate_scope_interface ctx scope = in let pos = Mark.get (ScopeName.get_info scope.scope_uid) in Mark.add pos - { - Ast.scope_decl_name = scope.scope_uid; - Ast.scope_decl_rules = []; - Ast.scope_sig; - Ast.scope_options = scope.scope_options; - } + { + Ast.scope_decl_name = scope.scope_uid; + Ast.scope_decl_rules = []; + Ast.scope_sig; + Ast.scope_options = scope.scope_options; + } let translate_scope (ctx : ctx) (exc_graphs : Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) - (scope : D.scope) - : untyped Ast.scope_decl Mark.pos = + (scope : D.scope) : untyped Ast.scope_decl Mark.pos = let scope_dependencies = Desugared.Dependency.build_scope_dependencies scope in @@ -758,7 +731,8 @@ let translate_scope scope_decl_rules @ new_rules) [] scope_ordering in - Mark.map (fun s -> { s with Ast.scope_decl_rules }) + Mark.map + (fun s -> { s with Ast.scope_decl_rules }) (translate_scope_interface ctx scope) (** {1 API} *) @@ -766,8 +740,8 @@ let translate_scope let translate_program (desugared : D.program) (exc_graphs : - Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) - : untyped Ast.program = + Desugared.Dependency.ExceptionsDependencies.t D.ScopeDef.Map.t) : + untyped Ast.program = (* First we give mappings to all the locations between Desugared and This involves creating a new Scopelang scope variable for every state of a Desugared variable. *) @@ -777,27 +751,26 @@ let translate_program have different types for Desugared.ScopeVar.t and Scopelang.ScopeVar.t *) ScopeName.Map.fold (fun _scope scope_decl ctx -> - ScopeVar.Map.fold - (fun scope_var (states : D.var_or_states) ctx -> - let var_name, var_pos = ScopeVar.get_info scope_var in - let new_var = - match states with - | D.WholeVar -> - WholeVar (ScopeVar.fresh (var_name, var_pos)) - | States states -> - let var_prefix = var_name ^ "_" in - let state_var state = - ScopeVar.fresh - (Mark.map (( ^ ) var_prefix) (StateName.get_info state)) - in - States (List.map (fun state -> state, state_var state) states) - in - { - ctx with - scope_var_mapping = - ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping; - }) - scope_decl.D.scope_vars ctx) + ScopeVar.Map.fold + (fun scope_var (states : D.var_or_states) ctx -> + let var_name, var_pos = ScopeVar.get_info scope_var in + let new_var = + match states with + | D.WholeVar -> WholeVar (ScopeVar.fresh (var_name, var_pos)) + | States states -> + let var_prefix = var_name ^ "_" in + let state_var state = + ScopeVar.fresh + (Mark.map (( ^ ) var_prefix) (StateName.get_info state)) + in + States (List.map (fun state -> state, state_var state) states) + in + { + ctx with + scope_var_mapping = + ScopeVar.Map.add scope_var new_var ctx.scope_var_mapping; + }) + scope_decl.D.scope_vars ctx) desugared.D.program_scopes { scope_var_mapping = ScopeVar.Map.empty; @@ -808,53 +781,63 @@ let translate_program in let ctx = make_ctx desugared in let rec gather_scope_vars acc modules = - ModuleName.Map.fold (fun _modname mctx acc -> + ModuleName.Map.fold + (fun _modname mctx acc -> let acc = gather_scope_vars acc mctx.modules in ScopeVar.Map.union (fun _ _ -> assert false) acc mctx.scope_var_mapping) modules acc in - let ctx = { ctx with scope_var_mapping = gather_scope_vars ctx.scope_var_mapping ctx.modules } in + let ctx = + { + ctx with + scope_var_mapping = gather_scope_vars ctx.scope_var_mapping ctx.modules; + } + in let rec process_decl_ctx ctx decl_ctx = let ctx_scopes = ScopeName.Map.map (fun out_str -> - let out_struct_fields = - ScopeVar.Map.fold - (fun var fld out_map -> - let var' = - match ScopeVar.Map.find var ctx.scope_var_mapping with - | WholeVar v -> v - | States l -> snd (List.hd (List.rev l)) - in - ScopeVar.Map.add var' fld out_map) - out_str.out_struct_fields ScopeVar.Map.empty - in - { out_str with out_struct_fields }) + let out_struct_fields = + ScopeVar.Map.fold + (fun var fld out_map -> + let var' = + match ScopeVar.Map.find var ctx.scope_var_mapping with + | WholeVar v -> v + | States l -> snd (List.hd (List.rev l)) + in + ScopeVar.Map.add var' fld out_map) + out_str.out_struct_fields ScopeVar.Map.empty + in + { out_str with out_struct_fields }) decl_ctx.ctx_scopes in - { decl_ctx with + { + decl_ctx with ctx_modules = - ModuleName.Map.mapi (fun modname decl_ctx -> + ModuleName.Map.mapi + (fun modname decl_ctx -> let ctx = ModuleName.Map.find modname ctx.modules in process_decl_ctx ctx decl_ctx) decl_ctx.ctx_modules; - ctx_scopes; } + ctx_scopes; + } in let rec process_modules program_ctx desugared = - ModuleName.Map.mapi (fun modname m_desugared -> + ModuleName.Map.mapi + (fun modname m_desugared -> let ctx = ModuleName.Map.find modname ctx.modules in { - Ast.program_topdefs = TopdefName.Map.empty; - program_scopes = - ScopeName.Map.map - (translate_scope_interface ctx) - m_desugared.D.program_scopes; - program_ctx; - program_modules = - process_modules - (ModuleName.Map.find modname program_ctx.ctx_modules) - m_desugared; - }) + Ast.program_topdefs = TopdefName.Map.empty; + program_scopes = + ScopeName.Map.map + (translate_scope_interface ctx) + m_desugared.D.program_scopes; + program_ctx; + program_modules = + process_modules + (ModuleName.Map.find modname program_ctx.ctx_modules) + m_desugared; + }) desugared.D.program_modules in let program_ctx = process_decl_ctx ctx desugared.D.program_ctx in @@ -862,14 +845,16 @@ let translate_program let program_topdefs = TopdefName.Map.mapi (fun id -> function - | Some e, ty -> Expr.unbox (translate_expr ctx e), ty - | None, (_, pos) -> - Message.raise_spanned_error pos "No definition found for %a" - TopdefName.format id) + | Some e, ty -> Expr.unbox (translate_expr ctx e), ty + | None, (_, pos) -> + Message.raise_spanned_error pos "No definition found for %a" + TopdefName.format id) desugared.program_topdefs in let program_scopes = - ScopeName.Map.map (translate_scope ctx exc_graphs) desugared.D.program_scopes + ScopeName.Map.map + (translate_scope ctx exc_graphs) + desugared.D.program_scopes in { Ast.program_topdefs; diff --git a/compiler/scopelang/print.ml b/compiler/scopelang/print.ml index 051ba69b..ec25575f 100644 --- a/compiler/scopelang/print.ml +++ b/compiler/scopelang/print.ml @@ -22,10 +22,10 @@ let struc ctx (fmt : Format.formatter) (name : StructName.t) - (path, fields : path * typ StructField.Map.t) : unit = - Format.fprintf fmt "%a %a%a %a %a@\n@[ %a@]@\n%a" Print.keyword "struct" - Print.path path - StructName.format name Print.punctuation "=" Print.punctuation "{" + ((path, fields) : path * typ StructField.Map.t) : unit = + Format.fprintf fmt "%a %a%a %a %a@\n@[ %a@]@\n%a" Print.keyword + "struct" Print.path path StructName.format name Print.punctuation "=" + Print.punctuation "{" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (field_name, typ) -> @@ -38,10 +38,9 @@ let enum ctx (fmt : Format.formatter) (name : EnumName.t) - (path, cases : path * typ EnumConstructor.Map.t) : unit = + ((path, cases) : path * typ EnumConstructor.Map.t) : unit = Format.fprintf fmt "%a %a%a %a @\n@[ %a@]" Print.keyword "enum" - Print.path path - EnumName.format name Print.punctuation "=" + Print.path path EnumName.format name Print.punctuation "=" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (fun fmt (field_name, typ) -> @@ -95,9 +94,8 @@ let scope ?debug ctx fmt (name, (decl, _pos)) = Format.fprintf fmt "%a %a" Print.keyword "assert" (Print.expr ?debug ()) e | Call ((scope_path, scope_name), subscope_name, _) -> - Format.fprintf fmt "%a %a%a%a%a%a" Print.keyword "call" - Print.path scope_path - ScopeName.format scope_name Print.punctuation "[" + Format.fprintf fmt "%a %a%a%a%a%a" Print.keyword "call" Print.path + scope_path ScopeName.format scope_name Print.punctuation "[" SubScopeName.format subscope_name Print.punctuation "]")) decl.scope_decl_rules diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 5a3462bb..76674d3f 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -23,7 +23,9 @@ open Catala_utils module Runtime = Runtime_ocaml.Runtime module ModuleName = String -(* TODO: should probably be turned into an Uid once we implement module import directives; that will incur an additional resolution work on all paths though *) +(* TODO: should probably be turned into an Uid once we implement module import + directives; that will incur an additional resolution work on all paths + though *) module ScopeName = Uid.Gen () module TopdefName = Uid.Gen () @@ -314,8 +316,10 @@ type except = ConflictError | EmptyError | NoValueProvided | Crash type untyped = { pos : Pos.t } [@@caml.unboxed] type typed = { pos : Pos.t; ty : typ } type 'a custom = { pos : Pos.t; custom : 'a } + +(** Using empty markings will ensure terms can't be constructed: used for + example in interfaces to ensure that they don't contain any expressions *) type nil = | - (** Using empty markings will ensure terms can't be constructed: used for example in interfaces to ensure that they don't contain any expressions *) (** The generic type of AST markings. Using a GADT allows functions to be polymorphic in the marking, but still do transformations on types when @@ -346,24 +350,34 @@ type lit = type path = ModuleName.t Mark.pos list -(** External references are resolved to strings that point to functions or constants in the end, but we need to keep different references for typing *) +(** External references are resolved to strings that point to functions or + constants in the end, but we need to keep different references for typing *) type external_ref = | External_value of TopdefName.t | External_scope of ScopeName.t (** Locations are handled differently in [desugared] and [scopelang] *) type 'a glocation = - | DesugaredScopeVar : - { name: ScopeVar.t Mark.pos; state: StateName.t option } + | DesugaredScopeVar : { + name : ScopeVar.t Mark.pos; + state : StateName.t option; + } -> < scopeVarStates : yes ; .. > glocation - | ScopelangScopeVar : - { name: ScopeVar.t Mark.pos } + | ScopelangScopeVar : { + name : ScopeVar.t Mark.pos; + } -> < scopeVarSimpl : yes ; .. > glocation - | SubScopeVar : - { path: path; scope: ScopeName.t; alias: SubScopeName.t Mark.pos; var: ScopeVar.t Mark.pos } + | SubScopeVar : { + path : path; + scope : ScopeName.t; + alias : SubScopeName.t Mark.pos; + var : ScopeVar.t Mark.pos; + } -> < explicitScopes : yes ; .. > glocation - | ToplevelVar : - { path: path; name: TopdefName.t Mark.pos } + | ToplevelVar : { + path : path; + name : TopdefName.t Mark.pos; + } -> < explicitScopes : yes ; .. > glocation type ('a, 'm) gexpr = (('a, 'm) naked_gexpr, 'm) marked @@ -463,7 +477,11 @@ and ('a, 'b, 'm) base_gexpr = -> ('a, < resolvedNames : yes ; .. >, 'm) base_gexpr (** Resolved struct/enums, after [desugared] *) (* Lambda-like *) - | EExternal : { path: path; name: external_ref Mark.pos} -> ('a, < explicitScopes: no ; .. >, 't) base_gexpr + | EExternal : { + path : path; + name : external_ref Mark.pos; + } + -> ('a, < explicitScopes : no ; .. >, 't) base_gexpr | EAssert : ('a, 'm) gexpr -> ('a, < assertions : yes ; .. >, 'm) base_gexpr (* Default terms *) | EDefault : { @@ -595,7 +613,4 @@ type decl_ctx = { ctx_modules : decl_ctx ModuleName.Map.t; } -type 'e program = { - decl_ctx : decl_ctx; - code_items : 'e code_item_list; -} +type 'e program = { decl_ctx : decl_ctx; code_items : 'e code_item_list } diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index 10ff445d..efe832bd 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -109,7 +109,10 @@ let subst binder vars = Bindlib.msubst binder (Array.of_list (List.map Mark.remove vars)) let evar v mark = Mark.add mark (Bindlib.box_var v) -let eexternal ~path ~name mark = Mark.add mark (Bindlib.box (EExternal {path; name})) + +let eexternal ~path ~name mark = + Mark.add mark (Bindlib.box (EExternal { path; name })) + let etuple args = Box.appn args @@ fun args -> ETuple args let etupleaccess e index size = @@ -296,8 +299,7 @@ let map estruct ~name ~fields m | EDStructAccess { path; name_opt; field; e } -> edstructaccess ~path ~name_opt ~field ~e:(f e) m - | EStructAccess { name; field; e } -> - estructaccess ~name ~field ~e:(f e) m + | EStructAccess { name; field; e } -> estructaccess ~name ~field ~e:(f e) m | EMatch { name; e; cases } -> let cases = EnumConstructor.Map.map f cases in ematch ~name ~e:(f e) ~cases m @@ -389,7 +391,7 @@ let map_gather acc, etupleaccess e index size m | EInj { name; cons; e } -> let acc, e = f e in - acc, einj ~name ~cons ~e m + acc, einj ~name ~cons ~e m | EAssert e -> let acc, e = f e in acc, eassert e m @@ -516,31 +518,33 @@ let compare_lit (l1 : lit) (l2 : lit) = | LDuration _, _ -> . | _, LDuration _ -> . -let compare_path = - List.compare (Mark.compare ModuleName.compare) +let compare_path = List.compare (Mark.compare ModuleName.compare) let compare_location (type a) (x : a glocation Mark.pos) (y : a glocation Mark.pos) = match Mark.remove x, Mark.remove y with - | DesugaredScopeVar { name = vx; state = None}, DesugaredScopeVar { name = vy; state = None} - | DesugaredScopeVar { name = vx; state = Some _}, DesugaredScopeVar { name = vy; state = None} - | DesugaredScopeVar { name = vx; state = None}, DesugaredScopeVar { name = vy; state = Some _} -> + | ( DesugaredScopeVar { name = vx; state = None }, + DesugaredScopeVar { name = vy; state = None } ) + | ( DesugaredScopeVar { name = vx; state = Some _ }, + DesugaredScopeVar { name = vy; state = None } ) + | ( DesugaredScopeVar { name = vx; state = None }, + DesugaredScopeVar { name = vy; state = Some _ } ) -> ScopeVar.compare (Mark.remove vx) (Mark.remove vy) - | DesugaredScopeVar {name = (x, _); state = Some sx}, DesugaredScopeVar {name = (y, _); state = Some sy} -> + | ( DesugaredScopeVar { name = x, _; state = Some sx }, + DesugaredScopeVar { name = y, _; state = Some sy } ) -> let cmp = ScopeVar.compare x y in if cmp = 0 then StateName.compare sx sy else cmp - | ScopelangScopeVar { name = (vx, _) }, ScopelangScopeVar { name = (vy, _) } -> + | ScopelangScopeVar { name = vx, _ }, ScopelangScopeVar { name = vy, _ } -> ScopeVar.compare vx vy - | ( SubScopeVar { alias = (xsubindex, _); var = (xsubvar, _); _}, - SubScopeVar { alias = (ysubindex, _); var = (ysubvar, _); _} ) -> + | ( SubScopeVar { alias = xsubindex, _; var = xsubvar, _; _ }, + SubScopeVar { alias = ysubindex, _; var = ysubvar, _; _ } ) -> let c = SubScopeName.compare xsubindex ysubindex in if c = 0 then ScopeVar.compare xsubvar ysubvar else c - | ToplevelVar { path = px; name = (vx, _) }, ToplevelVar { path = py; name = (vy, _) } -> - (match compare_path px py with - | 0 -> TopdefName.compare vx vy - | n -> n) + | ( ToplevelVar { path = px; name = vx, _ }, + ToplevelVar { path = py; name = vy, _ } ) -> ( + match compare_path px py with 0 -> TopdefName.compare vx vy | n -> n) | DesugaredScopeVar _, _ -> -1 | _, DesugaredScopeVar _ -> 1 | ScopelangScopeVar _, _ -> -1 @@ -554,11 +558,15 @@ let equal_path = List.equal (Mark.equal ModuleName.equal) let equal_location a b = compare_location a b = 0 let equal_except ex1 ex2 = ex1 = ex2 let compare_except ex1 ex2 = Stdlib.compare ex1 ex2 -let equal_external_ref ref1 ref2 = match ref1, ref2 with + +let equal_external_ref ref1 ref2 = + match ref1, ref2 with | External_value v1, External_value v2 -> TopdefName.equal v1 v2 | External_scope s1, External_scope s2 -> ScopeName.equal s1 s2 | (External_value _ | External_scope _), _ -> false -let compare_external_ref ref1 ref2 = match ref1, ref2 with + +let compare_external_ref ref1 ref2 = + match ref1, ref2 with | External_value v1, External_value v2 -> TopdefName.compare v1 v2 | External_scope s1, External_scope s2 -> ScopeName.compare s1 s2 | External_value _, _ -> -1 @@ -569,7 +577,7 @@ let compare_external_ref ref1 ref2 = match ref1, ref2 with (* weird indentation; see https://github.com/ocaml-ppx/ocamlformat/issues/2143 *) let rec equal_list : 'a. ('a, 't) gexpr list -> ('a, 't) gexpr list -> bool = - fun es1 es2 -> List.equal equal es1 es2 + fun es1 es2 -> List.equal equal es1 es2 and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = fun e1 e2 -> @@ -609,12 +617,15 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = StructName.equal s1 s2 && StructField.Map.equal equal fields1 fields2 | ( EDStructAccess { e = e1; field = f1; name_opt = s1; path = p1 }, EDStructAccess { e = e2; field = f2; name_opt = s2; path = p2 } ) -> - Option.equal StructName.equal s1 s2 && equal_path p1 p2 && Ident.equal f1 f2 && equal e1 e2 + Option.equal StructName.equal s1 s2 + && equal_path p1 p2 + && Ident.equal f1 f2 + && equal e1 e2 | ( EStructAccess { e = e1; field = f1; name = s1 }, EStructAccess { e = e2; field = f2; name = s2 } ) -> StructName.equal s1 s2 && StructField.equal f1 f2 && equal e1 e2 - | EInj { e = e1; cons = c1; name = n1 }, - EInj { e = e2; cons = c2; name = n2 } -> + | EInj { e = e1; cons = c1; name = n1 }, EInj { e = e2; cons = c2; name = n2 } + -> EnumName.equal n1 n2 && EnumConstructor.equal c1 c2 && equal e1 e2 | ( EMatch { e = e1; name = n1; cases = cases1 }, EMatch { e = e2; name = n2; cases = cases2 } ) -> @@ -623,9 +634,9 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = && EnumConstructor.Map.equal equal cases1 cases2 | ( EScopeCall { path = p1; scope = s1; args = fields1 }, EScopeCall { path = p2; scope = s2; args = fields2 } ) -> - ScopeName.equal s1 s2 && - equal_path p1 p2 && - ScopeVar.Map.equal equal fields1 fields2 + ScopeName.equal s1 s2 + && equal_path p1 p2 + && ScopeVar.Map.equal equal fields1 fields2 | ( ECustom { obj = obj1; targs = targs1; tret = tret1 }, ECustom { obj = obj2; targs = targs2; tret = tret2 } ) -> Type.equal_list targs1 targs2 && Type.equal tret1 tret2 && obj1 == obj2 diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index df34350d..01f3bc6d 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -36,7 +36,12 @@ val rebox : ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr (** Rebuild the whole term, re-binding all variables and exposing free variables *) val evar : ('a, 'm) gexpr Var.t -> 'm mark -> ('a, 'm) boxed_gexpr -val eexternal : path:path -> name:external_ref Mark.pos -> 'm mark -> (< explicitScopes: no; .. >, 'm) boxed_gexpr + +val eexternal : + path:path -> + name:external_ref Mark.pos -> + 'm mark -> + (< explicitScopes : no ; .. >, 'm) boxed_gexpr val bind : ('a, 'm) gexpr Var.t array -> @@ -108,37 +113,37 @@ val eraise : except -> 'm mark -> (< exceptions : yes ; .. >, 'm) boxed_gexpr val elocation : 'a glocation -> 'm mark -> ((< .. > as 'a), 'm) boxed_gexpr val estruct : - name: StructName.t -> - fields: ('a, 'm) boxed_gexpr StructField.Map.t -> + name:StructName.t -> + fields:('a, 'm) boxed_gexpr StructField.Map.t -> 'm mark -> ('a any, 'm) boxed_gexpr val edstructaccess : - path: path -> - name_opt: StructName.t option -> - field: Ident.t -> - e: ('a, 'm) boxed_gexpr -> + path:path -> + name_opt:StructName.t option -> + field:Ident.t -> + e:('a, 'm) boxed_gexpr -> 'm mark -> ((< syntacticNames : yes ; .. > as 'a), 'm) boxed_gexpr val estructaccess : - name: StructName.t -> - field: StructField.t -> - e: ('a, 'm) boxed_gexpr -> + name:StructName.t -> + field:StructField.t -> + e:('a, 'm) boxed_gexpr -> 'm mark -> ((< resolvedNames : yes ; .. > as 'a), 'm) boxed_gexpr val einj : - name: EnumName.t -> - cons: EnumConstructor.t -> - e: ('a, 'm) boxed_gexpr -> + name:EnumName.t -> + cons:EnumConstructor.t -> + e:('a, 'm) boxed_gexpr -> 'm mark -> ('a any, 'm) boxed_gexpr val ematch : - name: EnumName.t -> - e: ('a, 'm) boxed_gexpr -> - cases: ('a, 'm) boxed_gexpr EnumConstructor.Map.t -> + name:EnumName.t -> + e:('a, 'm) boxed_gexpr -> + cases:('a, 'm) boxed_gexpr EnumConstructor.Map.t -> 'm mark -> ('a any, 'm) boxed_gexpr diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 22254725..200f6cfc 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -549,31 +549,33 @@ let rec evaluate_expr : Message.raise_spanned_error pos "free variable found at evaluation (should not happen if term was \ well-typed)" - | EExternal { path; name } -> ( - let ty = - try - let ctx = Program.module_ctx ctx path in - match Mark.remove name with - | External_value name -> - TopdefName.Map.find name ctx.ctx_topdefs - | External_scope name -> - let scope_info = ScopeName.Map.find name ctx.ctx_scopes in - TArrow ([TStruct scope_info.in_struct_name, pos], - (TStruct scope_info.out_struct_name, pos)), - pos - with TopdefName.Map.Not_found _ | ScopeName.Map.Not_found _ -> - Message.raise_spanned_error pos "Reference to %a%a could not be resolved" - Print.path path Print.external_ref name - in - let runtime_path = - List.map Mark.remove path, + | EExternal { path; name } -> + let ty = + try + let ctx = Program.module_ctx ctx path in + match Mark.remove name with + | External_value name -> TopdefName.Map.find name ctx.ctx_topdefs + | External_scope name -> + let scope_info = ScopeName.Map.find name ctx.ctx_scopes in + ( TArrow + ( [TStruct scope_info.in_struct_name, pos], + (TStruct scope_info.out_struct_name, pos) ), + pos ) + with TopdefName.Map.Not_found _ | ScopeName.Map.Not_found _ -> + Message.raise_spanned_error pos + "Reference to %a%a could not be resolved" Print.path path + Print.external_ref name + in + let runtime_path = + ( List.map Mark.remove path, match Mark.remove name with | External_value name -> Mark.remove (TopdefName.get_info name) - | External_scope name -> Mark.remove (ScopeName.get_info name) - (* we have the guarantee that the two cases won't collide because they have different capitalisation rules inherited from the input *) - in - let o = Runtime.lookup_value runtime_path in - runtime_to_val evaluate_expr ctx m ty o) + | External_scope name -> Mark.remove (ScopeName.get_info name) ) + (* we have the guarantee that the two cases won't collide because they + have different capitalisation rules inherited from the input *) + in + let o = Runtime.lookup_value runtime_path in + runtime_to_val evaluate_expr ctx m ty o | EApp { f = e1; args } -> ( let e1 = evaluate_expr ctx e1 in let args = List.map (evaluate_expr ctx) args in diff --git a/compiler/shared_ast/optimizations.ml b/compiler/shared_ast/optimizations.ml index e06c1eb6..c614dfeb 100644 --- a/compiler/shared_ast/optimizations.ml +++ b/compiler/shared_ast/optimizations.ml @@ -409,12 +409,15 @@ let test_iota_reduction_2 () = let matchA = Expr.ematch - ~e:(Expr.ematch ~e:(num 1) ~name:enumT - ~cases:(cases_of_list - [ - (consB, fun x -> injBe (injB x)); (consA, fun _x -> injAe (num 20)); - ]) - nomark) + ~e: + (Expr.ematch ~e:(num 1) ~name:enumT + ~cases: + (cases_of_list + [ + (consB, fun x -> injBe (injB x)); + (consA, fun _x -> injAe (num 20)); + ]) + nomark) ~name:enumT ~cases:(cases_of_list [consA, injC; consB, injD]) nomark diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index c2502c69..0c8ac20d 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -73,20 +73,20 @@ let tlit (fmt : Format.formatter) (l : typ_lit) : unit = let module_name ppf m = Format.fprintf ppf "@{%a@}" ModuleName.format m let path ppf p = - Format.pp_print_list ~pp_sep:(fun _ () -> ()) + Format.pp_print_list + ~pp_sep:(fun _ () -> ()) (fun ppf m -> - Format.fprintf ppf "%a@{.@}" - module_name (Mark.remove m)) + Format.fprintf ppf "%a@{.@}" module_name (Mark.remove m)) ppf p let location (type a) (fmt : Format.formatter) (l : a glocation) : unit = match l with | DesugaredScopeVar { name; _ } -> ScopeVar.format fmt (Mark.remove name) | ScopelangScopeVar { name; _ } -> ScopeVar.format fmt (Mark.remove name) - | SubScopeVar { alias=subindex; var=subvar; _ } -> + | SubScopeVar { alias = subindex; var = subvar; _ } -> Format.fprintf fmt "%a.%a" SubScopeName.format (Mark.remove subindex) ScopeVar.format (Mark.remove subvar) - | ToplevelVar { path=p; name } -> + | ToplevelVar { path = p; name } -> path fmt p; TopdefName.format fmt (Mark.remove name) @@ -103,11 +103,12 @@ let external_ref fmt er = let rec module_ctx ctx = function | [] -> ctx - | (modname, mpos) :: path -> + | (modname, mpos) :: path -> ( match ModuleName.Map.find_opt modname ctx.ctx_modules with | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname - | Some ctx -> module_ctx ctx path + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format + modname + | Some ctx -> module_ctx ctx path) let rec typ_gen (ctx : decl_ctx option) @@ -137,15 +138,14 @@ let rec typ_gen pp_color_string (List.hd colors) fmt ")" | TStruct s -> ( match ctx with - | None -> - StructName.format fmt s + | None -> StructName.format fmt s | Some ctx -> let p, fields = StructName.Map.find s ctx.ctx_structs in - if StructField.Map.is_empty fields then - (path fmt p; StructName.format fmt s) + if StructField.Map.is_empty fields then ( + path fmt p; + StructName.format fmt s) else - Format.fprintf fmt "@[%a%a %a@,%a@;<0 -2>%a@]" - path p + Format.fprintf fmt "@[%a%a %a@,%a@;<0 -2>%a@]" path p StructName.format s (pp_color_string (List.hd colors)) "{" @@ -166,14 +166,14 @@ let rec typ_gen | None -> Format.fprintf fmt "@[%a@]" EnumName.format e | Some ctx -> let p, def = EnumName.Map.find e ctx.ctx_enums in - Format.fprintf fmt "@[%a%a%a%a%a@]" path p EnumName.format e punctuation "[" + Format.fprintf fmt "@[%a%a%a%a%a@]" path p EnumName.format e + punctuation "[" (EnumConstructor.Map.format_bindings ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " punctuation "|") (fun fmt pp_case mty -> Format.fprintf fmt "%t%a@ %a" pp_case punctuation ":" (typ ~colors) mty)) - def - punctuation "]") + def punctuation "]") | TOption t -> Format.fprintf fmt "@[%a@ %a@]" base_type "eoption" (typ ~colors) t | TArrow ([t1], t2) -> @@ -528,7 +528,7 @@ module ExprGen (C : EXPR_PARAM) = struct else match Mark.remove e with | EVar v -> var fmt v - | EExternal {path=p; name} -> + | EExternal { path = p; name } -> path fmt p; external_ref fmt name | ETuple es -> @@ -871,9 +871,9 @@ let enum decl_ctx fmt (pp_name : Format.formatter -> unit) - (p, c : path * typ EnumConstructor.Map.t) = - Format.fprintf fmt "@[%a %a%t %a@ %a@]" keyword "type" path p pp_name punctuation - "=" + ((p, c) : path * typ EnumConstructor.Map.t) = + Format.fprintf fmt "@[%a %a%t %a@ %a@]" keyword "type" path p pp_name + punctuation "=" (EnumConstructor.Map.format_bindings ~pp_sep:(fun _ _ -> ()) (fun fmt pp_n ty -> @@ -888,7 +888,7 @@ let struct_ decl_ctx fmt (pp_name : Format.formatter -> unit) - (p, c : path * typ StructField.Map.t) = + ((p, c) : path * typ StructField.Map.t) = Format.fprintf fmt "@[@[@[%a %a%t %a@;%a@]@;%a@]%a@]@;" keyword "type" path p pp_name punctuation "=" punctuation "{" (StructField.Map.format_bindings diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index f78893c3..20fce92e 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -41,11 +41,12 @@ let empty_ctx = let rec module_ctx ctx = function | [] -> ctx - | (modname, mpos) :: path -> + | (modname, mpos) :: path -> ( match ModuleName.Map.find_opt modname ctx.ctx_modules with | None -> - Message.raise_spanned_error mpos "Module %a not found" ModuleName.format modname - | Some ctx -> module_ctx ctx path + Message.raise_spanned_error mpos "Module %a not found" ModuleName.format + modname + | Some ctx -> module_ctx ctx path) let get_scope_body { code_items; _ } scope = match diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 5b5252cb..dda3553e 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -23,7 +23,8 @@ open Definitions val empty_ctx : decl_ctx val module_ctx : decl_ctx -> ModuleName.t Mark.pos list -> decl_ctx -(** Follows a path to get the corresponding context for type and value declarations. Errors out if the module is not found *) +(** Follows a path to get the corresponding context for type and value + declarations. Errors out if the module is not found *) (** {2 Transformations} *) diff --git a/compiler/shared_ast/scope.mli b/compiler/shared_ast/scope.mli index 143b3d54..9c65cb0f 100644 --- a/compiler/shared_ast/scope.mli +++ b/compiler/shared_ast/scope.mli @@ -1,5 +1,5 @@ -(* This file is part of the Catala compiler, a specification language for tax -< and social benefits computation rules. Copyright (C) 2020-2022 Inria, +(* This file is part of the Catala compiler, a specification language for tax < + and social benefits computation rules. Copyright (C) 2020-2022 Inria, contributor: Denis Merigoux , Alain Delaët-Tixeuil , Louis Gesbert diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index d83d3b1c..9c20358a 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -125,13 +125,16 @@ let rec format_typ "(" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ") - (fun fmt t -> - format_typ fmt ~colors:(List.tl colors) t)) + (fun fmt t -> format_typ fmt ~colors:(List.tl colors) t)) ts (pp_color_string (List.hd colors)) ")" - | TStruct s -> Print.path fmt (fst (A.StructName.Map.find s ctx.A.ctx_structs)); A.StructName.format fmt s - | TEnum e -> Print.path fmt (fst (A.EnumName.Map.find e ctx.A.ctx_enums)); A.EnumName.format fmt e + | TStruct s -> + Print.path fmt (fst (A.StructName.Map.find s ctx.A.ctx_structs)); + A.StructName.format fmt s + | TEnum e -> + Print.path fmt (fst (A.EnumName.Map.find e ctx.A.ctx_enums)); + A.EnumName.format fmt e | TOption t -> Format.fprintf fmt "@[option %a@]" (format_typ_with_parens ~colors:(List.tl colors)) @@ -346,10 +349,12 @@ module Env = struct let rec module_env path env = match path with | [] -> env - | (modname, mpos) :: path -> + | (modname, mpos) :: path -> ( match A.ModuleName.Map.find_opt modname env.modules with - | None -> Message.raise_spanned_error mpos "Module %a not found" A.ModuleName.format modname - | Some env -> module_env path env + | None -> + Message.raise_spanned_error mpos "Module %a not found" + A.ModuleName.format modname + | Some env -> module_env path env) let add v tau t = { t with vars = Var.Map.add v tau t.vars } let add_var v typ t = add v (ast_to_typ typ) t @@ -428,12 +433,12 @@ and typecheck_expr_top_down : | A.ELocation loc -> let ty_opt = match loc with - | DesugaredScopeVar {name;_} | ScopelangScopeVar {name} -> + | DesugaredScopeVar { name; _ } | ScopelangScopeVar { name } -> Env.get_scope_var env (Mark.remove name) - | SubScopeVar {path; scope; var; _} -> + | SubScopeVar { path; scope; var; _ } -> let env = Env.module_env path env in Env.get_subscope_out_var env scope (Mark.remove var) - | ToplevelVar {path; name} -> + | ToplevelVar { path; name } -> let env = Env.module_env path env in Env.get_toplevel_var env (Mark.remove name) in @@ -447,12 +452,8 @@ and typecheck_expr_top_down : Expr.elocation loc (mark_with_tau_and_unify (ast_to_typ ty)) | A.EStruct { name; fields } -> let mark = ty_mark (TStruct name) in - let _path, str_ast = - A.StructName.Map.find name ctx.A.ctx_structs - in - let str = - A.StructName.Map.find name env.structs - in + let _path, str_ast = A.StructName.Map.find name ctx.A.ctx_structs in + let str = A.StructName.Map.find name env.structs in let _check_fields : unit = let missing_fields, extra_fields = A.StructField.Map.fold @@ -627,14 +628,11 @@ and typecheck_expr_top_down : in Expr.ematch ~e:e1' ~name ~cases mark | A.EMatch { e = e1; name; cases } -> - let _path, cases_ty = - A.EnumName.Map.find name ctx.A.ctx_enums - in + let _path, cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in let mark = mark_with_tau_and_unify t_ret in let e1' = - typecheck_expr_top_down ~leave_unresolved ctx env - (unionfind (TEnum name)) + typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TEnum name)) e1 in let cases = @@ -683,28 +681,29 @@ and typecheck_expr_top_down : "Variable %s not found in the current context" (Bindlib.name_of v) in Expr.evar (Var.translate v) (mark_with_tau_and_unify tau') - | A.EExternal {path; name} -> + | A.EExternal { path; name } -> let ctx = Program.module_ctx ctx path in let ty = let not_found pr x = Message.raise_spanned_error pos_e - "Could not resolve the reference to %a%a.@ Make sure the corresponding \ - module was properly loaded?" - Print.path path - pr x + "Could not resolve the reference to %a%a.@ Make sure the \ + corresponding module was properly loaded?" + Print.path path pr x in match Mark.remove name with - | A.External_value name -> - (try - ast_to_typ (A.TopdefName.Map.find name ctx.ctx_topdefs) - with A.TopdefName.Map.Not_found _ -> not_found A.TopdefName.format name) - | A.External_scope name -> - (try - let scope_info = A.ScopeName.Map.find name ctx.ctx_scopes in - ast_to_typ (TArrow ([TStruct scope_info.in_struct_name, pos_e], - (TStruct scope_info.out_struct_name, pos_e)), - pos_e) - with A.ScopeName.Map.Not_found _ -> not_found A.ScopeName.format name) + | A.External_value name -> ( + try ast_to_typ (A.TopdefName.Map.find name ctx.ctx_topdefs) + with A.TopdefName.Map.Not_found _ -> + not_found A.TopdefName.format name) + | A.External_scope name -> ( + try + let scope_info = A.ScopeName.Map.find name ctx.ctx_scopes in + ast_to_typ + ( TArrow + ( [TStruct scope_info.in_struct_name, pos_e], + (TStruct scope_info.out_struct_name, pos_e) ), + pos_e ) + with A.ScopeName.Map.Not_found _ -> not_found A.ScopeName.format name) in Expr.eexternal ~path ~name (mark_with_tau_and_unify ty) | A.ELit lit -> Expr.elit lit (ty_mark (lit_type lit)) @@ -1033,30 +1032,30 @@ let program ~leave_unresolved prg = ctx_structs = A.StructName.Map.mapi (fun s_name (path, fields) -> - path, - A.StructField.Map.mapi - (fun f_name (t : A.typ) -> - match Mark.remove t with - | TAny -> - typ_to_ast ~leave_unresolved - (A.StructField.Map.find f_name - (A.StructName.Map.find s_name new_env.structs)) - | _ -> t) - fields) + ( path, + A.StructField.Map.mapi + (fun f_name (t : A.typ) -> + match Mark.remove t with + | TAny -> + typ_to_ast ~leave_unresolved + (A.StructField.Map.find f_name + (A.StructName.Map.find s_name new_env.structs)) + | _ -> t) + fields )) prg.decl_ctx.ctx_structs; ctx_enums = A.EnumName.Map.mapi (fun e_name (path, cons) -> - path, - A.EnumConstructor.Map.mapi - (fun cons_name (t : A.typ) -> - match Mark.remove t with - | TAny -> - typ_to_ast ~leave_unresolved - (A.EnumConstructor.Map.find cons_name - (A.EnumName.Map.find e_name new_env.enums)) - | _ -> t) - cons) + ( path, + A.EnumConstructor.Map.mapi + (fun cons_name (t : A.typ) -> + match Mark.remove t with + | TAny -> + typ_to_ast ~leave_unresolved + (A.EnumConstructor.Map.find cons_name + (A.EnumName.Map.find e_name new_env.enums)) + | _ -> t) + cons )) prg.decl_ctx.ctx_enums; }; } diff --git a/compiler/shared_ast/var.ml b/compiler/shared_ast/var.ml index d3f33e3c..8f889f98 100644 --- a/compiler/shared_ast/var.ml +++ b/compiler/shared_ast/var.ml @@ -92,6 +92,7 @@ module Map = struct open M type k0 = M.key + exception Not_found = M.Not_found type nonrec ('e, 'x) t = 'x t diff --git a/compiler/shared_ast/var.mli b/compiler/shared_ast/var.mli index b6257242..2f321cfc 100644 --- a/compiler/shared_ast/var.mli +++ b/compiler/shared_ast/var.mli @@ -57,8 +57,8 @@ end Extend as needed *) module Map : sig type ('e, 'x) t - type k0 + exception Not_found of k0 val empty : ('e, 'x) t diff --git a/compiler/surface/ast.ml b/compiler/surface/ast.ml index e1f5b890..6bdc8e9f 100644 --- a/compiler/surface/ast.ml +++ b/compiler/surface/ast.ml @@ -310,7 +310,8 @@ and law_structure = | CodeBlock of code_block * source_repr * bool (* Metadata if true *) and interface = code_block -(** Invariant: an interface shall only contain [*Decl] elements, or [Topdef] elements with [topdef_expr = None] *) +(** Invariant: an interface shall only contain [*Decl] elements, or [Topdef] + elements with [topdef_expr = None] *) and program = { program_items : law_structure list; diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index 055979e6..5e37c614 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -231,7 +231,7 @@ let rec parse_source_file { program_items = program.Ast.program_items; program_source_files = source_file_name :: program.Ast.program_source_files; - program_modules = [] + program_modules = []; } (** Expands the include directives in a parsing result, thus parsing new source @@ -267,8 +267,7 @@ and expand_includes Ast.program_source_files = acc.Ast.program_source_files @ new_sources; Ast.program_items = acc.Ast.program_items @ [Ast.LawHeading (heading, commands')]; - Ast.program_modules = - acc.Ast.program_modules @ new_modules; + Ast.program_modules = acc.Ast.program_modules @ new_modules; } | i -> { acc with Ast.program_items = acc.Ast.program_items @ [i] }) { @@ -302,8 +301,7 @@ let get_interface program = (** {1 API} *) let load_interface source_file language = - parse_source_file source_file language - |> get_interface + parse_source_file source_file language |> get_interface let parse_top_level_file (source_file : Cli.input_file) diff --git a/compiler/surface/parser_driver.mli b/compiler/surface/parser_driver.mli index 0819446f..68d4dba3 100644 --- a/compiler/surface/parser_driver.mli +++ b/compiler/surface/parser_driver.mli @@ -19,11 +19,10 @@ open Catala_utils -val load_interface : - Cli.input_file -> - Cli.backend_lang -> - Ast.interface -(** Reads only declarations in metadata in the supplied input file, and only keeps type information *) +val load_interface : Cli.input_file -> Cli.backend_lang -> Ast.interface +(** Reads only declarations in metadata in the supplied input file, and only + keeps type information *) val parse_top_level_file : Cli.input_file -> Cli.backend_lang -> Ast.program -(** Parses a catala file (handling file includes) and returns a program. Modules in the program are returned empty, use [load_interface] to fill them. *) +(** Parses a catala file (handling file includes) and returns a program. Modules + in the program are returned empty, use [load_interface] to fill them. *) diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index 329128fd..a3ab4f5c 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -667,11 +667,7 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = accesses *) let accessors = List.hd (Datatype.get_accessors z3_struct) in let _path, fields = StructName.Map.find name ctx.ctx_decl.ctx_structs in - let idx_mappings = - List.combine - (StructField.Map.keys fields) - accessors - in + let idx_mappings = List.combine (StructField.Map.keys fields) accessors in let _, accessor = List.find (fun (field1, _) -> StructField.equal field field1) idx_mappings in @@ -687,11 +683,7 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = let ctrs = Datatype.get_constructors z3_enum in let _path, cons_map = EnumName.Map.find name ctx.ctx_decl.ctx_enums in (* This should always succeed if the expression is well-typed in dcalc *) - let idx_mappings = - List.combine - (EnumConstructor.Map.keys cons_map) - ctrs - in + let idx_mappings = List.combine (EnumConstructor.Map.keys cons_map) ctrs in let _, ctr = List.find (fun (cons1, _) -> EnumConstructor.equal cons cons1)