diff --git a/compiler/catala_utils/cli.ml b/compiler/catala_utils/cli.ml index f975bc2b..3c88b006 100644 --- a/compiler/catala_utils/cli.ml +++ b/compiler/catala_utils/cli.ml @@ -39,7 +39,6 @@ type options = { mutable message_format : message_format_enum; mutable trace : bool; mutable plugins_dirs : string list; - mutable build_dir : string option; mutable disable_warnings : bool; mutable max_prec_digits : int; } @@ -58,7 +57,6 @@ let globals = message_format = Human; trace = false; plugins_dirs = []; - build_dir = None; disable_warnings = false; max_prec_digits = 20; } @@ -71,7 +69,6 @@ let enforce_globals ?message_format ?trace ?plugins_dirs - ?build_dir ?disable_warnings ?max_prec_digits () = @@ -82,7 +79,6 @@ let enforce_globals Option.iter (fun x -> globals.message_format <- x) message_format; Option.iter (fun x -> globals.trace <- x) trace; Option.iter (fun x -> globals.plugins_dirs <- x) plugins_dirs; - Option.iter (fun x -> globals.build_dir <- x) build_dir; Option.iter (fun x -> globals.disable_warnings <- x) disable_warnings; Option.iter (fun x -> globals.max_prec_digits <- x) max_prec_digits; globals @@ -112,7 +108,7 @@ module Flags = struct | _ -> assert false ) in required - & pos 0 (some converter) None + & pos ~rev:true 0 (some converter) None & Arg.info [] ~docv:"FILE" ~docs:Manpage.s_arguments ~doc:"Catala master file to be compiled." @@ -191,15 +187,6 @@ module Flags = struct in value & opt_all string default & info ["plugin-dir"] ~docv:"DIR" ~env ~doc - let build_dir = - value - & opt (some string) None - & info ["build-dir"] ~docv:"DIR" - ~doc: - "Directory where build artefacts are expected to be found. This \ - doesn't affect outptuts, but is used when looking up compiled \ - modules." - let disable_warnings = value & flag @@ -223,14 +210,13 @@ module Flags = struct message_format trace plugins_dirs - build_dir disable_warnings max_prec_digits : options = if debug then Printexc.record_backtrace true; (* This sets some global refs for convenience, but most importantly returns the options record. *) enforce_globals ~language ~debug ~color ~message_format ~trace - ~plugins_dirs ~build_dir ~disable_warnings ~max_prec_digits () + ~plugins_dirs ~disable_warnings ~max_prec_digits () in Term.( const make @@ -240,7 +226,6 @@ module Flags = struct $ message_format $ trace $ plugins_dirs - $ build_dir $ disable_warnings $ max_prec_digits) @@ -253,6 +238,13 @@ module Flags = struct Term.(const make $ input_file $ flags) end + let include_dirs = + value + & opt_all string [] + & info ["I";"include"] ~docv:"DIR" + ~doc: + "Include directory to lookup for compiled module files." + let check_invariants = value & flag @@ -314,17 +306,6 @@ module Flags = struct "Performs closure conversion on the lambda calculus. Implies \ $(b,--avoid-exceptions) and $(b,--optimize)." - let link_modules = - value - & opt_all file [] - & info ["use"; "u"] ~docv:"FILE" - ~doc: - "Specifies an additional module to be linked to the Catala program. \ - $(i,FILE) must be a catala file with a metadata section expressing \ - what is exported ; for interpretation, a compiled OCaml shared \ - module by the same basename (either .cmo or .cmxs) will be \ - expected." - let disable_counterexamples = value & flag @@ -334,6 +315,14 @@ module Flags = struct "Disables the search for counterexamples. Useful when you want a \ deterministic output from the Catala compiler, since provers can \ have some randomness in them." + + let build_dirs = + value + & opt_all string ["."; "_build"] + & info ["build-dir"] ~docv:"DIR" + ~env:(Cmd.Env.info "CATALA_BUILD_DIR") + ~doc: + "Directory where compiled modules are expected to be found (this option does not affect catala outputs)" end (* Retrieve current version from dune *) diff --git a/compiler/catala_utils/cli.mli b/compiler/catala_utils/cli.mli index 4aca040d..61fd688f 100644 --- a/compiler/catala_utils/cli.mli +++ b/compiler/catala_utils/cli.mli @@ -43,7 +43,6 @@ type options = private { mutable message_format : message_format_enum; mutable trace : bool; mutable plugins_dirs : string list; - mutable build_dir : string option; mutable disable_warnings : bool; mutable max_prec_digits : int; } @@ -63,7 +62,6 @@ val enforce_globals : ?message_format:message_format_enum -> ?trace:bool -> ?plugins_dirs:string list -> - ?build_dir:string option -> ?disable_warnings:bool -> ?max_prec_digits:int -> unit -> @@ -101,8 +99,9 @@ module Flags : sig val optimize : bool Term.t val avoid_exceptions : bool Term.t val closure_conversion : bool Term.t - val link_modules : string list Term.t + val include_dirs : string list Term.t val disable_counterexamples : bool Term.t + val build_dirs : string list Term.t end (** {2 Command-line application} *) diff --git a/compiler/catala_utils/file.ml b/compiler/catala_utils/file.ml index 6566ad46..5839f21a 100644 --- a/compiler/catala_utils/file.ml +++ b/compiler/catala_utils/file.ml @@ -117,12 +117,27 @@ let check_directory d = if Sys.is_directory d then Some d else None with Unix.Unix_error _ | Sys_error _ -> None -let ( / ) = Filename.concat +let check_file f = + try if Sys.is_directory f then None else Some f + with Unix.Unix_error _ | Sys_error _ -> None + +let ( / ) a b = + if a = "" || a = Filename.current_dir_name then b + else Filename.concat a b let dirname = Filename.dirname let ( /../ ) a b = dirname a / b let ( -.- ) file ext = Filename.chop_extension file ^ "." ^ ext -let equal = String.equal -let compare = String.compare + +let path_to_list path = + String.split_on_char Filename.dir_sep.[0] path + |> List.filter (fun d -> d <> "") + +let equal a b = + String.equal (String.lowercase_ascii a) (String.lowercase_ascii b) + +let compare a b = + String.compare (String.lowercase_ascii a) (String.lowercase_ascii b) + let format ppf t = Format.fprintf ppf "\"@{%s@}\"" t module Set = Set.Make (struct @@ -131,6 +146,13 @@ module Set = Set.Make (struct let compare = compare end) +module Map = Map.Make (struct + type nonrec t = t + + let compare = compare + let format = format +end) + let scan_tree f t = let is_dir t = try Sys.is_directory t @@ -143,7 +165,7 @@ let scan_tree f t = Sys.readdir d |> Array.to_list |> List.filter not_hidden - |> (if d = "." then fun t -> t else List.map (fun t -> d / t)) + |> List.map (fun t -> d / t) |> do_files and do_files flist = let dirs, files = @@ -154,3 +176,42 @@ let scan_tree f t = (Seq.filter_map f (List.to_seq files)) in do_files [t] + +module Tree = struct + type path = t + + type item = F | D of t + and t = (path * item) Map.t Lazy.t + + let empty = lazy Map.empty + + let rec build path = lazy + (Array.fold_left + (fun m f -> + let path = path / f in + match Sys.is_directory path with + | true -> Map.add f (path, D (build path)) m + | false -> Map.add f (path, F) m + | exception Sys_error _ -> m) + Map.empty + (Sys.readdir path)) + + let subtree t path = + let rec aux t = function + | [] -> t + | dir :: path -> + match Map.find_opt dir (Lazy.force t) with + | Some (_, D sub) -> aux sub path + | Some (_, F) | None -> raise Not_found + in + aux t (path_to_list path) + + let lookup t path = + try + let t = subtree t (dirname path) in + match Map.find_opt (Filename.basename path) (Lazy.force t) with + | Some (path, F) -> Some path + | Some (_, D _) | None -> None + with Not_found -> None + +end diff --git a/compiler/catala_utils/file.mli b/compiler/catala_utils/file.mli index 644d06d0..f2025607 100644 --- a/compiler/catala_utils/file.mli +++ b/compiler/catala_utils/file.mli @@ -85,9 +85,12 @@ val check_directory : t -> t option (** Checks if the given directory exists and returns it normalised (as per [Unix.realpath]). *) +val check_file : t -> t option +(** Returns its argument if it exists and is a plain file, [None] otherwise. Does not do resolution like [check_directory]. *) + val ( / ) : t -> t -> t (** [Filename.concat]: Sugar to allow writing - [File.("some" / "relative" / "path")] *) + [File.("some" / "relative" / "path")]. As an exception, if the lhs is [.], returns the rhs unchanged. *) val dirname : t -> t (** [Filename.dirname], re-exported for convenience *) @@ -100,18 +103,42 @@ val ( -.- ) : t -> string -> t with the given one (which shouldn't contain a dot) *) val equal : t -> t -> bool -(** String comparison no fancy file resolution *) +(** Case-insensitive string comparison (no file resolution whatsoever) *) val compare : t -> t -> int -(** String comparison no fancy file resolution *) +(** Case-insensitive string comparison (no file resolution whatsoever) *) val format : Format.formatter -> t -> unit (** Formats a filename in a consistent style, with double-quotes and color (when the output supports) *) module Set : Set.S with type elt = t +module Map : Map.S with type key = t val scan_tree : (t -> 'a option) -> t -> 'a Seq.t (** Recursively scans a directory for files. Directories or files matching ".*" or "_*" are ignored. Unreadable files or subdirectories are ignored with a debug message. If [t] is a plain file, scan just that non-recursively. *) + +module Tree: sig + (** A lazy tree structure mirroring the filesystem ; uses the comparison from File, so paths are case-insensitive. *) + + type path = t (** Alias for [File.t] *) + + type item = + | F (** Plain file *) + | D of t (** Directory with subtree *) + and t = (path * item) Map.t Lazy.t + (** Contents of a directory, lazily loaded. The map keys are the basenames of the files and subdirectories, while the values contain the original path (with correct capitalisation) *) + + val empty: t + + val build: path -> t + (** Lazily builds a [Tree.path] from the files read at [path]. The names in the maps are qualified (i.e. they all start with ["path/"]) *) + + val subtree: t -> path -> t + (** Looks up a path within a lazy tree *) + + val lookup: t -> path -> path option + (** Checks if there is a matching plain file (case-insensitively) ; and returns its path with the correct case if so *) +end diff --git a/compiler/catala_web_interpreter.ml b/compiler/catala_web_interpreter.ml index baf74bc9..c22dc6a1 100644 --- a/compiler/catala_web_interpreter.ml +++ b/compiler/catala_web_interpreter.ml @@ -23,7 +23,7 @@ let () = ~language:(Some language) ~debug:false ~color:Never ~trace () in let prg, ctx, _type_order = - Passes.dcalc options ~link_modules:[] ~optimize:false + Passes.dcalc options ~includes:File.Tree.empty ~optimize:false ~check_invariants:false in Shared_ast.Interpreter.interpret_program_dcalc prg diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index 38c4faaf..2b57d2b5 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -1527,7 +1527,16 @@ let translate_program (ctxt : Name_resolution.context) (surface : S.program) : (fun prgm child -> process_structure prgm child) prgm children | S.CodeBlock (block, _, _) -> process_code_block ctxt prgm block - | S.LawInclude _ | S.LawText _ | S.ModuleUse _ | S.ModuleDef _ -> prgm + | S.ModuleDef ((name, pos) as mname) -> + let file = Filename.basename (Pos.get_file pos) in + if not File.(equal name (Filename.remove_extension file)) + then + Message.raise_spanned_error pos + "Module declared as %a, which does not match the file name %a" + ModuleName.format (ModuleName.of_string mname) + File.format file + else prgm + | S.LawInclude _ | S.LawText _ | S.ModuleUse _ -> prgm in let desugared = List.fold_left diff --git a/compiler/driver.ml b/compiler/driver.ml index 3da8d9e7..1075c763 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -42,63 +42,68 @@ let get_lang options file = @{%s@}, and @{--language@} was not specified" filename) -let load_module_interfaces options program files = - let module MS = ModuleName.Set in - let to_set intf_list = - MS.of_list - (List.map (fun (mname, _) -> ModuleName.of_string mname) intf_list) +let load_module_interfaces includes program = + (* Recurse into program modules, looking up files in [using] and loading them *) + let err_req_pos chain = + List.map (fun m -> Some "Module required from", ModuleName.pos m) chain + in + let find_module req_chain m = + let fname_base = ModuleName.to_string m in + let required_from_file = Pos.get_file (ModuleName.pos m) in + match + Option.to_list (File.check_file File.(required_from_file /../ fname_base)) @ + List.filter_map + (fun (ext, _) -> + File.Tree.lookup includes (fname_base ^ ext)) + extensions + with + | [] -> + Message.raise_multispanned_error (err_req_pos (m::req_chain)) + "Required module not found: %a" + ModuleName.format m + | [f] -> + f + | ms -> + Message.raise_multispanned_error + (err_req_pos (m::req_chain)) + "Required module %a matches multiple files: %a" + ModuleName.format m + (Format.pp_print_list ~pp_sep:Format.pp_print_space File.format) + ms in - let used_modules = to_set program.Surface.Ast.program_modules in let load_file f = - let lang = get_lang options (FileName f) in + let lang = + List.assoc + (List.assoc (Filename.extension f) extensions) + Cli.languages + in let (mname, intf), using = - Surface.Parser_driver.load_interface (FileName f) lang + Surface.Parser_driver.load_interface (Cli.FileName f) lang in (ModuleName.of_string mname, intf), using in - let module_interfaces = List.map load_file files in - let rec check (required, acc) interfaces = - let required, acc, remaining = - List.fold_left - (fun (required, acc, skipped) (((modname, intf), using) as modl) -> - if MS.mem modname required then - let required = - List.fold_left - (fun req m -> MS.add (ModuleName.of_string m) req) - required using - in - required, ((modname :> string Mark.pos), intf) :: acc, skipped - else required, acc, modl :: skipped) - (required, acc, []) interfaces - in - if List.length remaining < List.length interfaces then - (* Loop until fixpoint *) - check (required, acc) remaining - else required, acc, remaining + let rec aux req_chain acc modules = + List.fold_left (fun acc mname -> + let m = ModuleName.of_string mname in + if List.mem_assoc m acc then acc else + let f = find_module req_chain m in + let (m', intf), using = load_file f in + if not (ModuleName.equal m m') then + Message.raise_multispanned_error + ((Some "Module name declaration", ModuleName.pos m') :: + err_req_pos (m::req_chain)) + "Mismatching module name declaration:"; + let acc = (m', intf) :: acc in + aux (m::req_chain) acc using + ) + acc modules in - let required, loaded, unused = check (used_modules, []) module_interfaces in - let missing = - MS.diff required - (MS.of_list (List.map (fun (m, _) -> ModuleName.of_string m) loaded)) + let program_modules = + aux [] [] (List.map fst program.Surface.Ast.program_modules) + |> List.map (fun (m, i) -> (m : ModuleName.t :> string Mark.pos), i) in - if (not (MS.is_empty missing)) || unused <> [] then - Message.raise_multispanned_error - (List.map - (fun m -> - ( Some - (Format.asprintf "Required module not found: %a" - ModuleName.format m), - ModuleName.pos m )) - (ModuleName.Set.elements missing) - @ List.map - (fun ((m, _), _) -> - ( Some - (Format.asprintf "No use was found for this module: %a" - ModuleName.format m), - ModuleName.pos m )) - unused) - "Modules used from the program don't match the command-line"; - loaded + { program with + Surface.Ast.program_modules } module Passes = struct (* Each pass takes only its cli options, then calls upon its dependent passes @@ -108,19 +113,19 @@ module Passes = struct Message.emit_debug "@{=@} @{%s@} @{=@}" (String.uppercase_ascii s) - let surface options ~link_modules : Surface.Ast.program * Cli.backend_lang = + let surface options ~includes : Surface.Ast.program * Cli.backend_lang = debug_pass_name "surface"; let language = get_lang options options.input_file in let prg = Surface.Parser_driver.parse_top_level_file options.input_file language in let prg = Surface.Fill_positions.fill_pos_with_legislative_info prg in - let program_modules = load_module_interfaces options prg link_modules in - { prg with program_modules }, language + let prg = load_module_interfaces includes prg in + prg, language - let desugared options ~link_modules : + let desugared options ~includes : Desugared.Ast.program * Desugared.Name_resolution.context = - let prg, _ = surface options ~link_modules in + let prg, _ = surface options ~includes in debug_pass_name "desugared"; Message.emit_debug "Name resolution..."; let ctx = Desugared.Name_resolution.form_context prg in @@ -138,12 +143,12 @@ module Passes = struct uids from strings. Maybe a reduced form should be included directly in [prg] for that purpose *) - let scopelang options ~link_modules : + let scopelang options ~includes : untyped Scopelang.Ast.program * Desugared.Name_resolution.context * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t = - let prg, ctx = desugared options ~link_modules in + let prg, ctx = desugared options ~includes in debug_pass_name "scopelang"; let exceptions_graphs = Scopelang.From_desugared.build_exceptions_graph prg @@ -153,11 +158,11 @@ module Passes = struct in prg, ctx, exceptions_graphs - let dcalc options ~link_modules ~optimize ~check_invariants : + let dcalc options ~includes ~optimize ~check_invariants : typed Dcalc.Ast.program * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list = - let prg, ctx, _ = scopelang options ~link_modules in + let prg, ctx, _ = scopelang options ~includes in debug_pass_name "dcalc"; let type_ordering = Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs @@ -193,7 +198,7 @@ module Passes = struct let lcalc options - ~link_modules + ~includes ~optimize ~check_invariants ~avoid_exceptions @@ -202,7 +207,7 @@ module Passes = struct * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list = let prg, ctx, type_ordering = - dcalc options ~link_modules ~optimize ~check_invariants + dcalc options ~includes ~optimize ~check_invariants in debug_pass_name "lcalc"; let avoid_exceptions = avoid_exceptions || closure_conversion in @@ -243,7 +248,7 @@ module Passes = struct let scalc options - ~link_modules + ~includes ~optimize ~check_invariants ~avoid_exceptions @@ -252,7 +257,7 @@ module Passes = struct * Desugared.Name_resolution.context * Scopelang.Dependency.TVertex.t list = let prg, ctx, type_ordering = - lcalc options ~link_modules ~optimize ~check_invariants ~avoid_exceptions + lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in debug_pass_name "scalc"; @@ -353,6 +358,14 @@ module Commands = struct second_part first_part ScopeName.format scope_uid) second_part ) + let include_flags = + let mk dirs = + lazy (dirs + |> List.map (fun d -> Lazy.force (File.Tree.build d)) + |> List.fold_left (File.Map.union (fun _ x _ -> Some x)) File.Map.empty) + in + Term.(const mk $ Cli.Flags.include_dirs) + let get_output ?ext options output_file = File.get_out_channel ~source_file:options.Cli.input_file ~output_file ?ext () @@ -362,7 +375,7 @@ module Commands = struct ~output_file ?ext () let makefile options output = - let prg, _ = Passes.surface options ~link_modules:[] in + let prg, _ = Passes.surface options ~includes:File.Tree.empty in let backend_extensions_list = [".tex"] in let source_file = match options.Cli.input_file with @@ -393,7 +406,7 @@ module Commands = struct Term.(const makefile $ Cli.Flags.Global.options $ Cli.Flags.output) let html options output print_only_law wrap_weaved_output = - let prg, language = Passes.surface options ~link_modules:[] in + let prg, language = Passes.surface options ~includes:File.Tree.empty in Message.emit_debug "Weaving literate program into HTML"; let output_file, with_output = get_output_format options ~ext:".html" output @@ -421,7 +434,7 @@ module Commands = struct $ Cli.Flags.wrap_weaved_output) let latex options output print_only_law wrap_weaved_output = - let prg, language = Passes.surface options ~link_modules:[] in + let prg, language = Passes.surface options ~includes:File.Tree.empty in Message.emit_debug "Weaving literate program into LaTeX"; let output_file, with_output = get_output_format options ~ext:".tex" output @@ -448,8 +461,8 @@ module Commands = struct $ Cli.Flags.print_only_law $ Cli.Flags.wrap_weaved_output) - let exceptions options link_modules ex_scope ex_variable = - let _, ctxt, exceptions_graphs = Passes.scopelang options ~link_modules in + let exceptions options includes ex_scope ex_variable = + let _, ctxt, exceptions_graphs = Passes.scopelang options ~includes in let scope_uid = get_scope_uid ctxt ex_scope in let variable_uid = get_variable_uid ctxt scope_uid ex_variable in Desugared.Print.print_exceptions_graph scope_uid variable_uid @@ -467,12 +480,12 @@ module Commands = struct Term.( const exceptions $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.ex_scope $ Cli.Flags.ex_variable) - let scopelang options link_modules output ex_scope_opt = - let prg, ctx, _ = Passes.scopelang options ~link_modules in + let scopelang options includes output ex_scope_opt = + let prg, ctx, _ = Passes.scopelang options ~includes in let _output_file, with_output = get_output_format options output in with_output @@ fun fmt -> @@ -496,12 +509,12 @@ module Commands = struct Term.( const scopelang $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.output $ Cli.Flags.ex_scope_opt) - let typecheck options link_modules = - let prg, _, _ = Passes.scopelang options ~link_modules in + let typecheck options includes = + let prg, _, _ = Passes.scopelang options ~includes in Message.emit_debug "Typechecking..."; let _type_ordering = Scopelang.Dependency.check_type_cycles prg.program_ctx.ctx_structs @@ -519,11 +532,11 @@ module Commands = struct Cmd.v (Cmd.info "typecheck" ~doc:"Parses and typechecks a Catala program, without interpreting it.") - Term.(const typecheck $ Cli.Flags.Global.options $ Cli.Flags.link_modules) + Term.(const typecheck $ Cli.Flags.Global.options $ include_flags) - let dcalc options link_modules output optimize ex_scope_opt check_invariants = + let dcalc options includes output optimize ex_scope_opt check_invariants = let prg, ctx, _ = - Passes.dcalc options ~link_modules ~optimize ~check_invariants + Passes.dcalc options ~includes ~optimize ~check_invariants in let _output_file, with_output = get_output_format options output in with_output @@ -560,7 +573,7 @@ module Commands = struct Term.( const dcalc $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.ex_scope_opt @@ -568,13 +581,13 @@ module Commands = struct let proof options - link_modules + includes optimize ex_scope_opt check_invariants disable_counterexamples = let prg, ctx, _ = - Passes.dcalc options ~link_modules ~optimize ~check_invariants + Passes.dcalc options ~includes ~optimize ~check_invariants in Verification.Globals.setup ~optimize ~disable_counterexamples; let vcs = @@ -592,7 +605,7 @@ module Commands = struct Term.( const proof $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.optimize $ Cli.Flags.ex_scope_opt $ Cli.Flags.check_invariants @@ -621,11 +634,11 @@ module Commands = struct result) results - let interpret_dcalc options link_modules optimize check_invariants ex_scope = + let interpret_dcalc options includes optimize check_invariants build_dirs ex_scope = let prg, ctx, _ = - Passes.dcalc options ~link_modules ~optimize ~check_invariants + Passes.dcalc options ~includes ~optimize ~check_invariants in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules ~build_dirs prg; print_interpretation_results options Interpreter.interpret_program_dcalc prg (get_scope_uid ctx ex_scope) @@ -639,14 +652,15 @@ module Commands = struct Term.( const interpret_dcalc $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.optimize $ Cli.Flags.check_invariants + $ Cli.Flags.build_dirs $ Cli.Flags.ex_scope) let lcalc options - link_modules + includes output optimize check_invariants @@ -654,7 +668,7 @@ module Commands = struct closure_conversion ex_scope_opt = let prg, ctx, _ = - Passes.lcalc options ~link_modules ~optimize ~check_invariants + Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in let _output_file, with_output = get_output_format options output in @@ -680,7 +694,7 @@ module Commands = struct Term.( const lcalc $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants @@ -690,16 +704,18 @@ module Commands = struct let interpret_lcalc options - link_modules + includes optimize check_invariants avoid_exceptions closure_conversion + build_dirs ex_scope = let prg, ctx, _ = - Passes.lcalc options ~link_modules ~optimize ~check_invariants + Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in + Interpreter.load_runtime_modules ~build_dirs prg; print_interpretation_results options Interpreter.interpret_program_lcalc prg (get_scope_uid ctx ex_scope) @@ -713,23 +729,24 @@ module Commands = struct Term.( const interpret_lcalc $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.optimize $ Cli.Flags.check_invariants $ Cli.Flags.avoid_exceptions $ Cli.Flags.closure_conversion + $ Cli.Flags.build_dirs $ Cli.Flags.ex_scope) let ocaml options - link_modules + includes output optimize check_invariants avoid_exceptions closure_conversion = let prg, _, type_ordering = - Passes.lcalc options ~link_modules ~optimize ~check_invariants + Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in let output_file, with_output = @@ -749,7 +766,7 @@ module Commands = struct Term.( const ocaml $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants @@ -758,7 +775,7 @@ module Commands = struct let scalc options - link_modules + includes output optimize check_invariants @@ -766,7 +783,7 @@ module Commands = struct closure_conversion ex_scope_opt = let prg, ctx, _ = - Passes.scalc options ~link_modules ~optimize ~check_invariants + Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in let _output_file, with_output = get_output_format options output in @@ -795,7 +812,7 @@ module Commands = struct Term.( const scalc $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants @@ -805,14 +822,14 @@ module Commands = struct let python options - link_modules + includes output optimize check_invariants avoid_exceptions closure_conversion = let prg, _, type_ordering = - Passes.scalc options ~link_modules ~optimize ~check_invariants + Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in @@ -832,17 +849,17 @@ module Commands = struct Term.( const python $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants $ Cli.Flags.avoid_exceptions $ Cli.Flags.closure_conversion) - let r options link_modules output optimize check_invariants closure_conversion + let r options includes output optimize check_invariants closure_conversion = let prg, _, type_ordering = - Passes.scalc options ~link_modules ~optimize ~check_invariants + Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions:false ~closure_conversion in @@ -858,7 +875,7 @@ module Commands = struct Term.( const r $ Cli.Flags.Global.options - $ Cli.Flags.link_modules + $ include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants diff --git a/compiler/driver.mli b/compiler/driver.mli index a66134cf..af28226f 100644 --- a/compiler/driver.mli +++ b/compiler/driver.mli @@ -27,24 +27,24 @@ val main : unit -> unit module Passes : sig val surface : Cli.options -> - link_modules:string list -> + includes:File.Tree.t -> Surface.Ast.program * Cli.backend_lang val desugared : Cli.options -> - link_modules:string list -> + includes:File.Tree.t -> Desugared.Ast.program * Desugared.Name_resolution.context val scopelang : Cli.options -> - link_modules:string list -> + includes:File.Tree.t -> Shared_ast.untyped Scopelang.Ast.program * Desugared.Name_resolution.context * Desugared.Dependency.ExceptionsDependencies.t Desugared.Ast.ScopeDef.Map.t val dcalc : Cli.options -> - link_modules:string list -> + includes:File.Tree.t -> optimize:bool -> check_invariants:bool -> Shared_ast.typed Dcalc.Ast.program @@ -53,7 +53,7 @@ module Passes : sig val lcalc : Cli.options -> - link_modules:string list -> + includes:File.Tree.t -> optimize:bool -> check_invariants:bool -> avoid_exceptions:bool -> @@ -64,7 +64,7 @@ module Passes : sig val scalc : Cli.options -> - link_modules:string list -> + includes:File.Tree.t -> optimize:bool -> check_invariants:bool -> avoid_exceptions:bool -> @@ -99,6 +99,8 @@ module Commands : sig string -> Desugared.Ast.ScopeDef.t + val include_flags : File.Tree.t Cmdliner.Term.t + val commands : unit Cmdliner.Cmd.t list (** The list of built-in catala subcommands, as expected by [Cmdliner.Cmd.group] *) diff --git a/compiler/plugins/api_web.ml b/compiler/plugins/api_web.ml index 8c4261bf..abb13350 100644 --- a/compiler/plugins/api_web.ml +++ b/compiler/plugins/api_web.ml @@ -428,7 +428,7 @@ module To_jsoo = struct end let run - link_modules + includes output optimize check_invariants @@ -438,7 +438,7 @@ let run if not options.Cli.trace then Message.raise_error "This plugin requires the --trace flag."; let prg, _, type_ordering = - Driver.Passes.lcalc options ~link_modules ~optimize ~check_invariants + Driver.Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in let modname = @@ -472,7 +472,7 @@ let run let term = let open Cmdliner.Term in const run - $ Cli.Flags.link_modules + $ Driver.Commands.include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants diff --git a/compiler/plugins/dune b/compiler/plugins/dune index 4f84c8e6..faeb5e19 100644 --- a/compiler/plugins/dune +++ b/compiler/plugins/dune @@ -38,13 +38,6 @@ (flags (-linkall)) (libraries shared_ast catala.driver ocamlgraph)) -(library - (name modules) - (public_name catala.plugins.modules) - (synopsis "Catala plugin for experimental module handling tooling") - (modules modules) - (libraries shared_ast catala.driver)) - (documentation (package catala) (mld_files plugins)) diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index dd3df927..ae5eba38 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -1385,12 +1385,12 @@ let options = $ Cli.Flags.output $ base_src_url) -let run link_modules optimize ex_scope explain_options global_options = +let run includes optimize ex_scope explain_options global_options = let prg, ctx, _ = - Driver.Passes.dcalc global_options ~link_modules ~optimize + Driver.Passes.dcalc global_options ~includes ~optimize ~check_invariants:false in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules ~includes prg; let scope = Driver.Commands.get_scope_uid ctx ex_scope in (* let result_expr, env = interpret_program prg scope in *) let g, base_vars, env = program_to_graph explain_options prg scope in @@ -1436,7 +1436,7 @@ let run link_modules optimize ex_scope explain_options global_options = let term = let open Cmdliner.Term in const run - $ Cli.Flags.link_modules + $ Driver.Commands.include_flags $ Cli.Flags.optimize $ Cli.Flags.ex_scope $ options diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index aae3d97e..2959350a 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -206,7 +206,7 @@ module To_json = struct end let run - link_modules + includes output optimize check_invariants @@ -215,7 +215,7 @@ let run ex_scope options = let prg, ctx, _ = - Driver.Passes.lcalc options ~link_modules ~optimize ~check_invariants + Driver.Passes.lcalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in let output_file, with_output = @@ -233,7 +233,7 @@ let run let term = let open Cmdliner.Term in const run - $ Cli.Flags.link_modules + $ Driver.Commands.include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index da114e55..368d7a81 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -257,11 +257,11 @@ let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : (* -- Plugin registration -- *) -let run link_modules optimize check_invariants ex_scope options = +let run includes optimize check_invariants ex_scope options = let prg, ctx, _ = - Driver.Passes.dcalc options ~link_modules ~optimize ~check_invariants + Driver.Passes.dcalc options ~includes ~optimize ~check_invariants in - Interpreter.load_runtime_modules prg; + Interpreter.load_runtime_modules ~includes prg; let scope = Driver.Commands.get_scope_uid ctx ex_scope in let result_expr, _env = interpret_program prg scope in let fmt = Format.std_formatter in @@ -270,7 +270,7 @@ let run link_modules optimize check_invariants ex_scope options = let term = let open Cmdliner.Term in const run - $ Cli.Flags.link_modules + $ Driver.Commands.include_flags $ Cli.Flags.optimize $ Cli.Flags.check_invariants $ Cli.Flags.ex_scope diff --git a/compiler/plugins/modules.ml b/compiler/plugins/modules.ml deleted file mode 100644 index ab2a711f..00000000 --- a/compiler/plugins/modules.ml +++ /dev/null @@ -1,233 +0,0 @@ -(* This file is part of the Catala compiler, a specification language for tax - and social benefits computation rules. Copyright (C) 2020 Inria, contributor: - Louis Gesbert . - - Licensed under the Apache License, Version 2.0 (the "License"); you may not - use this file except in compliance with the License. You may obtain a copy of - the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - License for the specific language governing permissions and limitations under - the License. *) - -open Catala_utils - -let action_flag = - let docs = "ACTIONS" in - let open Cmdliner.Arg in - required - & vflag None - [ - ( Some `Compile, - info ["compile"] ~docs - ~doc: - "Compiles a Catala file into a module: a $(b,.cmxs) file that \ - can be used by the Catala interpreter." - (* "and $(b,cmo) and $(b,cmx) files that can be linked into an OCaml - program" *) ); - ( Some `Link, - info ["link"] ~docs - ~doc: - "Compiles and links a catala program into a binary (using the \ - ocaml backend). Specify a main scope using the $(b,--scope) \ - flag to be run upon execution. This is still pretty useless at \ - the moment besides for testing purposes, as there is no way to \ - feed input to the generated program, and the output will be \ - silent. Assertions will be checked, though." ); - ] - -let gen_ocaml options link_modules optimize check_invariants main = - let prg, ctx, type_ordering = - Driver.Passes.lcalc options ~link_modules ~optimize ~check_invariants - ~avoid_exceptions:false ~closure_conversion:false - in - let exec_scope = Option.map (Driver.Commands.get_scope_uid ctx) main in - let filename, with_output = - Driver.Commands.get_output_format options ~ext:".ml" None - in - with_output - @@ fun ppf -> - Lcalc.To_ocaml.format_program ppf ?exec_scope prg type_ordering; - Option.get filename - -let run_process cmd args = - Message.emit_debug "Running @[@{@{%s@} %a@}@}@]" cmd - (Format.pp_print_list ~pp_sep:Format.pp_print_space Format.pp_print_string) - args; - match - Unix.waitpid [] - (Unix.create_process cmd - (Array.of_list (cmd :: args)) - Unix.stdin Unix.stdout Unix.stderr) - with - | _, Unix.WEXITED 0 -> () - | _, _ -> Message.raise_error "Child process @{%s@} failed" cmd - -let with_flag flag args = - List.fold_right (fun p acc -> flag :: p :: acc) args [] - -let ocaml_libdir = - lazy - (try String.trim (File.process_out "opam" ["var"; "lib"]) - with Failure _ -> ( - try String.trim (File.process_out "ocamlc" ["-where"]) - with Failure _ -> ( - match File.(check_directory (dirname Sys.argv.(0) /../ "lib")) with - | Some d -> d - | None -> - Message.raise_error - "Could not locate the OCaml library directory, make sure OCaml or \ - opam is installed"))) - -let rec find_catala_project_root dir = - if Sys.file_exists File.(dir / "catala.opam") then Some dir - else - let dir' = File.dirname dir in - if dir' = dir then None else find_catala_project_root dir' - -let runtime_dir = - lazy - (let d = - match find_catala_project_root (Sys.getcwd ()) with - | Some root -> - (* Relative dir when running from catala source *) - File.( - root - / "_build" - / "install" - / "default" - / "lib" - / "catala" - / "runtime_ocaml") - | None -> ( - match - File.check_directory - File.(dirname Sys.argv.(0) /../ "lib" / "catala" / "runtime_ocaml") - with - | Some d -> d - | None -> File.(Lazy.force ocaml_libdir / "catala" / "runtime")) - in - match File.check_directory d with - | Some dir -> - Message.emit_debug "Catala runtime libraries found at @{%s@}." dir; - dir - | None -> - Message.raise_error - "@[Could not locate the Catala runtime library.@ Make sure that \ - either catala is correctly installed,@ or you are running from the \ - root of a compiled source tree.@]") - -let compile options link_modules optimize check_invariants = - let modname = - match options.Cli.input_file with - (* TODO: extract module name from directives *) - | FileName n -> Driver.modname_of_file n - | _ -> Message.raise_error "Input must be a file name for this command" - in - let basename = String.uncapitalize_ascii modname in - let ml_file = gen_ocaml options link_modules optimize check_invariants None in - let flags = ["-I"; Lazy.force runtime_dir] in - let shared_out = File.((ml_file /../ basename) ^ ".cmxs") in - Message.emit_debug "Compiling OCaml shared object file @{%s@}..." - shared_out; - run_process "ocamlopt" ("-shared" :: ml_file :: "-o" :: shared_out :: flags); - (* let byte_out = basename ^ ".cmo" in - * Message.emit_debug "Compiling OCaml byte-code object file @{%s@}..." byte_out; - * run_process "ocamlc" ("-c" :: ml_file :: "-o" :: byte_out :: flags); - * let native_out = basename ^ ".cmx" in - * Message.emit_debug "Compiling OCaml native object file @{%s@}..." native_out; - * run_process "ocamlopt" ("-c" :: ml_file :: "-o" :: native_out ::flags); *) - Message.emit_debug "Done." - -let link options link_modules optimize check_invariants output ex_scope_opt = - let ml_file = - gen_ocaml options link_modules optimize check_invariants ex_scope_opt - in - (* NOTE: assuming native target at the moment *) - let cmd = "ocamlopt" in - let ocaml_libdir = Lazy.force ocaml_libdir in - let runtime_dir = Lazy.force runtime_dir in - (* Recursive dependencies are expanded manually here. A shorter version would - use [ocamlfind ocalmopt -linkpkg -package] with just ppx_yojson_conv_lib, - zarith and dates_calc *) - let link_libs = - [ - "biniou"; - "easy-format"; - "yojson"; - "ppx_yojson_conv_lib"; - "zarith"; - "dates_calc"; - ] - in - let link_libdirs = - List.map - (fun lib -> - match File.(check_directory (ocaml_libdir / lib)) with - | None -> - Message.raise_error - "Required OCaml library not found at @{%s@}.@ Try `opam \ - install %s'" - File.(ocaml_libdir / lib) - lib - | Some l -> l) - link_libs - in - let runtime_lib = File.(runtime_dir / "runtime_ocaml.cmxa") in - let modules = - List.map (fun m -> Filename.remove_extension m ^ ".cmx") link_modules - in - let output = - match output with - | Some o -> o - | None -> Filename.remove_extension ml_file ^ ".exe" - in - let args = - with_flag "-I" link_libdirs - @ with_flag "-I" - (List.sort_uniq compare (List.map Filename.dirname modules)) - @ List.map - (fun lib -> String.map (function '-' -> '_' | c -> c) lib ^ ".cmxa") - link_libs - @ ("-I" :: runtime_dir :: runtime_lib :: modules) - @ [ml_file; "-o"; output] - in - run_process cmd args; - Message.emit_result "Successfully generated @{%s@}" output -(* Compile from ml and link the modules cmx. => ocamlfind ocamlopt -linkpkg - -package ppx_yojson_conv_lib -package zarith -package dates_calc -I - _build/default/runtimes/ocaml/.runtime_ocaml.objs/byte - _build/default/runtimes/ocaml/runtime_ocaml.cmxa ext.cmx extuse.ml *) - -let run - action - link_modules - optimize - check_invariants - output - ex_scope_opt - options = - match action with - | `Compile -> compile options link_modules optimize check_invariants - | `Link -> - link options link_modules optimize check_invariants ex_scope_opt output - -let term = - let open Cmdliner.Term in - const run - $ action_flag - $ Cli.Flags.link_modules - $ Cli.Flags.optimize - $ Cli.Flags.check_invariants - $ Cli.Flags.ex_scope_opt - $ Cli.Flags.output - -let () = - Driver.Plugin.register "module" term - ~doc: - "This plugin provides a few experimental tools related to module \ - generation and compilation" diff --git a/compiler/plugins/python.ml b/compiler/plugins/python.ml index 4c15aacb..85ad1f97 100644 --- a/compiler/plugins/python.ml +++ b/compiler/plugins/python.ml @@ -23,7 +23,7 @@ open Catala_utils let run - link_modules + includes output optimize check_invariants @@ -32,7 +32,7 @@ let run options = let open Driver.Commands in let prg, _, type_ordering = - Driver.Passes.scalc options ~link_modules ~optimize ~check_invariants + Driver.Passes.scalc options ~includes ~optimize ~check_invariants ~avoid_exceptions ~closure_conversion in @@ -45,7 +45,7 @@ let run let term = let open Cmdliner.Term in const run - $ Cli.Flags.link_modules + $ Driver.Commands.include_flags $ Cli.Flags.output $ Cli.Flags.optimize $ Cli.Flags.check_invariants diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index a07f2d11..4d8867b2 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -940,29 +940,47 @@ let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list reflect that. *) let evaluate_expr ctx lang e = delcustom (evaluate_expr ctx lang (addcustom e)) -let load_runtime_modules prg = - match ModuleName.Map.keys prg.decl_ctx.ctx_modules with - | [] -> () - | modules -> - Message.emit_debug "Loading shared modules... %a" - (fun ppf -> ModuleName.Map.format_keys ppf) - prg.decl_ctx.ctx_modules; - List.iter - (fun m -> - let srcfile = Pos.get_file (ModuleName.pos m) in - let obj_file = - File.((srcfile /../ ModuleName.to_string m) ^ ".cmo") - |> Dynlink.adapt_filename - in - let obj_file = - match Cli.globals.build_dir with - | None -> obj_file - | Some d -> File.(d / obj_file) - in - try Dynlink.loadfile obj_file - with Dynlink.Error dl_err -> - Message.raise_error - "Could not load module %a, has it been suitably compiled?@;\ - <1 2>@[%a@]" ModuleName.format m Format.pp_print_text - (Dynlink.error_message dl_err)) - modules +let load_runtime_modules ~build_dirs prg = + let load m = + let obj_base = + Dynlink.adapt_filename File.(Pos.get_file (ModuleName.pos m) /../ ModuleName.to_string m ^ ".cmo") + in + let possible_files = List.map File.(fun d -> d / obj_base) build_dirs in + match List.filter Sys.file_exists possible_files with + | [] -> + Message.raise_spanned_error + ~span_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") + (ModuleName.pos m) + "Compiled OCaml object %a not found. Make sure it has been suitably compiled, and use @{-I DIR@} if necessary." File.format obj_base + | [f] -> + (try Dynlink.loadfile f + with Dynlink.Error dl_err -> + Message.raise_error + "Error loading compiled module from %a:@;\ + <1 2>@[%a@]" File.format f + Format.pp_print_text + (Dynlink.error_message dl_err)) + | fs -> + Message.raise_spanned_error + ~span_msg:(fun ppf -> Format.pp_print_string ppf "Module defined here") + (ModuleName.pos m) + "@[Multiple compiled OCaml objects for %a found:@,- %a@]" + ModuleName.format m + (Format.pp_print_list ~pp_sep:(fun ppf () -> Format.fprintf ppf "@,- ") + File.format) + fs + in + let rec aux loaded decl_ctx = + ModuleName.Map.fold (fun mname sub_decl_ctx loaded -> + if ModuleName.Set.mem mname loaded then loaded else + let loaded = ModuleName.Set.add mname loaded in + let loaded = aux loaded sub_decl_ctx in + load mname; + loaded) + decl_ctx.ctx_modules loaded + in + Message.emit_debug "Loading shared modules... %a" + (fun ppf -> ModuleName.Map.format_keys ppf) + prg.decl_ctx.ctx_modules; + let (_loaded: ModuleName.Set.t) = aux ModuleName.Set.empty prg.decl_ctx in + () diff --git a/compiler/shared_ast/interpreter.mli b/compiler/shared_ast/interpreter.mli index d95cc387..2f2fcecc 100644 --- a/compiler/shared_ast/interpreter.mli +++ b/compiler/shared_ast/interpreter.mli @@ -72,8 +72,8 @@ val interpret_program_lcalc : providing for each argument a thunked empty default. Returns a list of all the computed values for the scope variables of the executed scope. *) -val load_runtime_modules : _ program -> unit +val load_runtime_modules : build_dirs:File.t list -> _ program -> unit (** Dynlink the runtime modules required by the given program, in order to make - them callable by the interpreter. If Cli.globals.build_dir is specified, the - runtime module names (as obtained by looking up the positions in the - program's module bindings) are assumed to be relative and looked up there. *) + them callable by the interpreter. + + The specified build dirs are used as prefixes to the catala files defining the modules: with {[["."; "_build"]]}, this means that the compiled artifact of [foo/bar.catala_en] will be searched in [foo/bar.cmxs] and [_build/foo/bar.cmxs] *) diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index d8b0c1a2..23dad967 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -273,7 +273,10 @@ and expand_includes match command with | Ast.ModuleDef id -> ( match acc.Ast.program_module_name with - | None -> { acc with Ast.program_module_name = Some id } + | None -> + { acc with Ast.program_module_name = Some id; + Ast.program_items = command :: acc.Ast.program_items; + } | Some id2 -> Message.raise_multispanned_error [None, Mark.get id; None, Mark.get id2]