From e224e87f71e6a87be456fcc08ab94266d028e8cb Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Wed, 19 Apr 2023 18:26:50 +0200 Subject: [PATCH] Wip support for modules (first working dynload test with compilation done by manual calls to ocaml) A few pieces of the puzzle: * Loading of interfaces only from Catala files * Registration of toplevel values in modules compiled to OCaml, to allow access using dynlink * Shady conversion from OCaml runtime values to/from Catala expressions, to allow interop (ffi) of compiled modules and the interpreter --- compiler/dcalc/from_scopelang.ml | 2 +- compiler/desugared/from_surface.ml | 23 +- compiler/desugared/name_resolution.mli | 1 - compiler/driver.ml | 30 +- compiler/lcalc/closure_conversion.ml | 6 +- compiler/lcalc/compile_with_exceptions.ml | 6 +- compiler/lcalc/compile_without_exceptions.ml | 1 + compiler/lcalc/to_ocaml.ml | 72 ++- compiler/lcalc/to_ocaml.mli | 4 +- compiler/plugins/lazy_interp.ml | 1 + compiler/scopelang/from_desugared.ml | 53 +-- compiler/shared_ast/definitions.ml | 36 +- compiler/shared_ast/expr.ml | 34 +- compiler/shared_ast/expr.mli | 8 + compiler/shared_ast/interpreter.ml | 409 +++++++++++++++++- compiler/shared_ast/interpreter.mli | 19 +- compiler/shared_ast/optimizations.ml | 32 +- compiler/shared_ast/print.ml | 4 + compiler/shared_ast/program.ml | 9 + compiler/shared_ast/program.mli | 4 + compiler/shared_ast/qident.ml | 53 +++ compiler/shared_ast/qident.mli | 36 ++ compiler/shared_ast/shared_ast.ml | 1 + compiler/shared_ast/typing.ml | 15 + compiler/surface/ast.ml | 2 + compiler/surface/parser_driver.ml | 40 +- compiler/surface/parser_driver.mli | 9 + compiler/verification/z3backend.real.ml | 1 + runtimes/ocaml/runtime.ml | 20 + runtimes/ocaml/runtime.mli | 18 + .../191_fix_record_name_confusion.catala_en | 7 + 31 files changed, 843 insertions(+), 113 deletions(-) create mode 100644 compiler/shared_ast/qident.ml create mode 100644 compiler/shared_ast/qident.mli diff --git a/compiler/dcalc/from_scopelang.ml b/compiler/dcalc/from_scopelang.ml index 63203c04..a710f87b 100644 --- a/compiler/dcalc/from_scopelang.ml +++ b/compiler/dcalc/from_scopelang.ml @@ -551,7 +551,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) : | EOp { op = Add_dat_dur _; tys } -> Expr.eop (Add_dat_dur ctx.date_rounding) tys m | EOp { op; tys } -> Expr.eop (Operator.translate op) tys m - | (EVar _ | EAbs _ | ELit _ | EStruct _ | EStructAccess _ | ETuple _ + | (EVar _ | EAbs _ | ELit _ | EExternal _ | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EEmptyError | EErrorOnEmpty _ | EArray _ | EIfThenElse _ ) as e -> Expr.map ~f:(translate_expr ctx) (e, m) diff --git a/compiler/desugared/from_surface.ml b/compiler/desugared/from_surface.ml index dbb20246..9351a425 100644 --- a/compiler/desugared/from_surface.ml +++ b/compiler/desugared/from_surface.ml @@ -360,8 +360,9 @@ let rec translate_expr | None -> Name_resolution.raise_unknown_identifier "for a local, scope-wide or global variable" (x, pos)))) - | Ident (_path, _x) -> - Message.raise_spanned_error pos "Qualified paths are not supported yet" + | Surface.Ast.Ident (path, x) -> + let path = List.map Mark.remove path in + Expr.eexternal (path, Mark.remove x) emark | Dotted (e, ((path, x), _ppos)) -> ( match path, Mark.remove e with | [], Ident ([], (y, _)) @@ -1044,8 +1045,8 @@ let process_def ExceptionToRule (name, pos)) | ExceptionToLabel label_str -> ( try - let label_id = Ident.Map.find (Mark.remove label_str) - scope_def_ctxt.label_idmap + let label_id = + Ident.Map.find (Mark.remove label_str) scope_def_ctxt.label_idmap in ExceptionToLabel (label_id, Mark.get label_str) with Not_found -> @@ -1412,6 +1413,7 @@ let translate_program }) ctxt.Name_resolution.scopes in + let translate_type t = Name_resolution.process_type ctxt t in { Ast.program_ctx = { @@ -1426,6 +1428,19 @@ let translate_program | _ -> acc) ctxt.Name_resolution.typedefs ScopeName.Map.empty; ctx_struct_fields = ctxt.Name_resolution.field_idmap; + ctx_modules = + List.fold_left + (fun map (path, def) -> + match def with + | ( Surface.Ast.Topdef + {topdef_name; topdef_type; _}, + _pos ) -> + Qident.Map.add (path, Mark.remove topdef_name) (translate_type topdef_type) map + | (ScopeDecl _ | StructDecl _ | EnumDecl _), _ (* as e *) -> + map + (* assert false (\* TODO *\) *) + | ScopeUse _, _ -> assert false) + Qident.Map.empty prgm.Surface.Ast.program_interfaces; }; Ast.program_topdefs = TopdefName.Map.empty; Ast.program_scopes; diff --git a/compiler/desugared/name_resolution.mli b/compiler/desugared/name_resolution.mli index f22c902f..bfb011b5 100644 --- a/compiler/desugared/name_resolution.mli +++ b/compiler/desugared/name_resolution.mli @@ -157,7 +157,6 @@ val get_scope : context -> Ident.t Mark.pos -> ScopeName.t val process_type : context -> Surface.Ast.typ -> typ (** Convert a surface base type to an AST type *) -(* Note: should probably be moved to a different module *) (** {1 API} *) diff --git a/compiler/driver.ml b/compiler/driver.ml index 058a4176..a4b24ba1 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -129,6 +129,9 @@ let get_variable_uid scope_uid) second_part ))) +let modname_of_file f = (* Fixme: make this more robust *) + String.capitalize_ascii Filename.(basename (remove_extension f)) + (** Entry function for the executable. Returns a negative number in case of error. Usage: [driver source_file options]*) let driver source_file (options : Cli.options) : int = @@ -189,6 +192,24 @@ let driver source_file (options : Cli.options) : int = Surface.Parser_driver.parse_top_level_file source_file language in let prgm = Surface.Fill_positions.fill_pos_with_legislative_info prgm in + let prgm = + (* FIXME: WIP placeholder *) + match Sys.getenv_opt "CATALA_INTF" with + | None | Some "" -> prgm + | Some str -> + let files = String.split_on_char ',' str in + List.fold_left + (fun prgm f -> + let lang = + Option.value ~default:Cli.En + @@ Option.bind + (List.assoc_opt (Filename.extension f) extensions) + (fun l -> List.assoc_opt l Cli.languages) + in + let modname = modname_of_file f in + Surface.Parser_driver.add_interface (FileName f) lang [modname] prgm) + prgm files + in let get_output ?ext = File.get_out_channel ~source_file ~output_file:options.output_file ?ext in @@ -490,7 +511,14 @@ let driver source_file (options : Cli.options) : int = Message.emit_debug "Compiling program into OCaml..."; Message.emit_debug "Writing to %s..." (Option.value ~default:"stdout" output_file); - Lcalc.To_ocaml.format_program fmt prgm type_ordering + let modname = + match source_file with + (* FIXME: WIP placeholder *) + | FileName n -> + Some (modname_of_file n) + | _ -> None + in + Lcalc.To_ocaml.format_program fmt ?modname prgm type_ordering | `Plugin (Plugin.Dcalc _) -> assert false | `Plugin (Plugin.Lcalc p) -> let output_file, _ = diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 72652c88..9f41bdba 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -38,7 +38,8 @@ let rec hoist_context_free_closures : let m = Mark.get e in match Mark.remove e with | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _ - | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _ | EVar _ -> + | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _ | EVar _ + | EExternal _ -> Expr.map_gather ~acc:[] ~join:( @ ) ~f:(hoist_context_free_closures ctx) e | EMatch { e; cases; name } -> let collected_closures, new_e = (hoist_context_free_closures ctx) e in @@ -98,7 +99,8 @@ let rec transform_closures_expr : let m = Mark.get e in match Mark.remove e with | EStruct _ | EStructAccess _ | ETuple _ | ETupleAccess _ | EInj _ | EArray _ - | ELit _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ | ECatch _ -> + | ELit _ | EExternal _ | EAssert _ | EOp _ | EIfThenElse _ | ERaise _ + | ECatch _ -> Expr.map_gather ~acc:Var.Set.empty ~join:Var.Set.union ~f:(transform_closures_expr ctx) e diff --git a/compiler/lcalc/compile_with_exceptions.ml b/compiler/lcalc/compile_with_exceptions.ml index fa36f91e..2d616408 100644 --- a/compiler/lcalc/compile_with_exceptions.ml +++ b/compiler/lcalc/compile_with_exceptions.ml @@ -74,9 +74,9 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed = | EDefault { excepts; just; cons } -> translate_default ctx excepts just cons (Mark.get e) | EOp { op; tys } -> Expr.eop (Operator.translate op) tys m - | ( ELit _ | EApp _ | EArray _ | EVar _ | EAbs _ | EIfThenElse _ | ETuple _ - | ETupleAccess _ | EInj _ | EAssert _ | EStruct _ | EStructAccess _ - | EMatch _ ) as e -> + | ( ELit _ | EApp _ | EArray _ | EVar _ | EExternal _ | EAbs _ | EIfThenElse _ + | ETuple _ | ETupleAccess _ | EInj _ | EAssert _ | EStruct _ + | EStructAccess _ | EMatch _ ) as e -> Expr.map ~f:(translate_expr ctx) (Mark.add m e) | _ -> . diff --git a/compiler/lcalc/compile_without_exceptions.ml b/compiler/lcalc/compile_without_exceptions.ml index 36433604..e093bef3 100644 --- a/compiler/lcalc/compile_without_exceptions.ml +++ b/compiler/lcalc/compile_without_exceptions.ml @@ -107,6 +107,7 @@ let rec trans (ctx : typed ctx) (e : typed D.expr) : (lcalc, typed) boxed_gexpr if (Var.Map.find x ctx.ctx_vars).info_pure then Ast.OptionMonad.return (Expr.evar (trans_var ctx x) m) ~mark else Expr.evar (trans_var ctx x) m + | EExternal eref -> Expr.eexternal eref mark | EApp { f = EVar v, _; args = [(ELit LUnit, _)] } -> (* Invariant: as users cannot write thunks, it can only come from prior compilation passes. Hence we can safely remove those. *) diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index 076c1886..c497e8a3 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -87,6 +87,8 @@ let avoid_keywords (s : string) : string = | "while" | "with" | "Stdlib" | "Runtime" | "Oper" -> s ^ "_user" | _ -> s +(* Fixme: this could cause clashes if the user program contains both e.g. [new] + and [new_user] *) let format_struct_name (fmt : Format.formatter) (v : StructName.t) : unit = Format.asprintf "%a" StructName.format_t v @@ -230,6 +232,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : in match Mark.remove e with | EVar v -> Format.fprintf fmt "%a" format_var v + | EExternal qid -> Qident.format fmt qid | ETuple es -> Format.fprintf fmt "@[(%a)@]" (Format.pp_print_list @@ -520,14 +523,15 @@ let rec format_scope_body_expr let format_code_items (ctx : decl_ctx) (fmt : Format.formatter) - (code_items : 'm Ast.expr code_item_list) : unit = + (code_items : 'm Ast.expr code_item_list) : 'm Ast.expr Var.t String.Map.t = Scope.fold_left - ~f:(fun () item var -> + ~f:(fun bnd item var -> match item with - | Topdef (_, typ, e) -> + | Topdef (name, typ, e) -> Format.fprintf fmt "@\n@\n@[let %a : %a =@\n%a@]" format_var var - format_typ typ (format_expr ctx) e - | ScopeDef (_, body) -> + format_typ typ (format_expr ctx) e; + String.Map.add (Mark.remove (TopdefName.get_info name)) var bnd + | ScopeDef (name, body) -> let scope_input_var, scope_body_expr = Bindlib.unbind body.scope_body_expr in @@ -536,22 +540,52 @@ let format_code_items (`Sname body.scope_body_input_struct) format_to_module_name (`Sname body.scope_body_output_struct) (format_scope_body_expr ctx) - scope_body_expr) - ~init:() code_items + scope_body_expr; + String.Map.add (Mark.remove (ScopeName.get_info name)) var bnd) + ~init:String.Map.empty code_items + +let format_module_registration + fmt + (bnd : 'm Ast.expr Var.t String.Map.t) + modname = + Format.pp_open_vbox fmt 2; + Format.pp_print_string fmt "let () ="; + Format.pp_print_space fmt (); + Format.pp_open_hvbox fmt 2; + Format.fprintf fmt "Runtime_ocaml.Runtime.register_module %S" modname; + Format.pp_print_space fmt (); + Format.pp_open_vbox fmt 2; + Format.pp_print_string fmt "[ "; + Format.pp_print_seq + ~pp_sep:(fun fmt () -> Format.pp_print_char fmt ';'; Format.pp_print_cut fmt ()) + (fun fmt (id, var) -> + Format.fprintf fmt "@[%S,@ Obj.repr %a@]" id format_var var) + fmt (String.Map.to_seq bnd); + Format.pp_close_box fmt (); + Format.pp_print_char fmt ' '; + Format.pp_print_string fmt "]"; + Format.pp_print_space fmt (); + Format.pp_print_string fmt "\"todo-module-hash\""; + Format.pp_close_box fmt (); + Format.pp_close_box fmt () + +let header = + {ocaml| +(** This file has been generated by the Catala compiler, do not edit! *) + +open Runtime_ocaml.Runtime + +[@@@ocaml.warning "-4-26-27-32-41-42"] + +|ocaml} let format_program (fmt : Format.formatter) + ?modname (p : 'm Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) : unit = - Format.fprintf fmt - "(** This file has been generated by the Catala compiler, do not edit! *)@\n\ - @\n\ - open Runtime_ocaml.Runtime@\n\ - @\n\ - [@@@@@@ocaml.warning \"-4-26-27-32-41-42\"]@\n\ - @\n\ - %a%a@\n\ - @?" - (format_ctx type_ordering) p.decl_ctx - (format_code_items p.decl_ctx) - p.code_items + Format.pp_print_string fmt header; + format_ctx type_ordering fmt p.decl_ctx; + let bnd = format_code_items p.decl_ctx fmt p.code_items in + Format.pp_print_newline fmt (); + Option.iter (format_module_registration fmt bnd) modname diff --git a/compiler/lcalc/to_ocaml.mli b/compiler/lcalc/to_ocaml.mli index 3c511b61..8d6eedb7 100644 --- a/compiler/lcalc/to_ocaml.mli +++ b/compiler/lcalc/to_ocaml.mli @@ -40,7 +40,9 @@ val format_var : Format.formatter -> 'm Var.t -> unit val format_program : Format.formatter -> + ?modname:string -> 'm Ast.program -> Scopelang.Dependency.TVertex.t list -> unit -(** Usage [format_program fmt p type_dependencies_ordering] *) +(** Usage [format_program fmt p type_dependencies_ordering]. If [modname] is + set, registers the module for dynamic loading *) diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml index f5910078..8de5adf7 100644 --- a/compiler/plugins/lazy_interp.ml +++ b/compiler/plugins/lazy_interp.ml @@ -209,6 +209,7 @@ let rec lazy_eval : | (ELit (LBool false), _), _ -> error e "Assert failure (%a)" Expr.format e | _ -> error e "Invalid assertion condition %a" Expr.format e) + | EExternal _, _ -> assert false (* todo *) | _ -> . let interpret_program (prg : ('dcalc, 'm) gexpr program) (scope : ScopeName.t) : diff --git a/compiler/scopelang/from_desugared.ml b/compiler/scopelang/from_desugared.ml index d3a14a72..a25dda45 100644 --- a/compiler/scopelang/from_desugared.ml +++ b/compiler/scopelang/from_desugared.ml @@ -46,6 +46,17 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : untyped Ast.expr boxed = let m = Mark.get e in match Mark.remove e with + | EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m + | EAbs { binder; tys } -> + let vars, body = Bindlib.unmbind binder in + let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in + let ctx = + List.fold_left2 + (fun ctx var new_var -> + { ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping }) + ctx (Array.to_list vars) (Array.to_list new_vars) + in + Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m | ELocation (SubScopeVar (s_name, ss_name, s_var)) -> (* When referring to a subscope variable in an expression, we are referring to the output, hence we take the last state. *) @@ -70,9 +81,6 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : | States states -> Mark.copy s_var (List.assoc state states))) m | ELocation (ToplevelVar v) -> Expr.elocation (ToplevelVar v) m - | EVar v -> Expr.evar (Var.Map.find v ctx.var_mapping) m - | EStruct { name; fields } -> - Expr.estruct name (StructField.Map.map (translate_expr ctx) fields) m | EDStructAccess { name_opt = None; _ } -> (* Note: this could only happen if disambiguation was disabled. If we want to support it, we should still allow this case when the field has only @@ -93,14 +101,6 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : field StructName.format_t name in Expr.estructaccess e' field name m - | ETuple es -> Expr.etuple (List.map (translate_expr ctx) es) m - | ETupleAccess { e; index; size } -> - Expr.etupleaccess (translate_expr ctx e) index size m - | EInj { e; cons; name } -> Expr.einj (translate_expr ctx e) cons name m - | EMatch { e; name; cases } -> - Expr.ematch (translate_expr ctx e) name - (EnumConstructor.Map.map (translate_expr ctx) cases) - m | EScopeCall { scope; args } -> Expr.escopecall scope (ScopeVar.Map.fold @@ -117,20 +117,6 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : ScopeVar.Map.add v' (translate_expr ctx e) args') args ScopeVar.Map.empty) m - | ELit - ((LBool _ | LInt _ | LRat _ | LMoney _ | LUnit | LDate _ | LDuration _) as - l) -> - Expr.elit l m - | EAbs { binder; tys } -> - let vars, body = Bindlib.unmbind binder in - let new_vars = Array.map (fun var -> Var.make (Bindlib.name_of var)) vars in - let ctx = - List.fold_left2 - (fun ctx var new_var -> - { ctx with var_mapping = Var.Map.add var new_var ctx.var_mapping }) - ctx (Array.to_list vars) (Array.to_list new_vars) - in - Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m | EApp { f = EOp { op; tys }, m1; args } -> let args = List.map (translate_expr ctx) args in Operator.kind_dispatch op @@ -144,19 +130,10 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) : | op, `Reversed -> Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m) | EOp _ -> assert false (* Only allowed within [EApp] *) - | EApp { f; args } -> - Expr.eapp (translate_expr ctx f) (List.map (translate_expr ctx) args) m - | EDefault { excepts; just; cons } -> - Expr.edefault - (List.map (translate_expr ctx) excepts) - (translate_expr ctx just) (translate_expr ctx cons) m - | EIfThenElse { cond; etrue; efalse } -> - Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue) - (translate_expr ctx efalse) - m - | EArray args -> Expr.earray (List.map (translate_expr ctx) args) m - | EEmptyError -> Expr.eemptyerror m - | EErrorOnEmpty e1 -> Expr.eerroronempty (translate_expr ctx e1) m + | ( EStruct _ | ETuple _ | ETupleAccess _ | EInj _ | EMatch _ | ELit _ + | EApp _ | EDefault _ | EIfThenElse _ | EArray _ | EEmptyError + | EErrorOnEmpty _ | EExternal _ ) as e -> + Expr.map ~f:(translate_expr ctx) (e, m) (** {1 Rule tree construction} *) diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index d0e2f9cb..bbe0179a 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -56,8 +56,13 @@ module StateName = Uid.Gen () (** These types allow to select the features present in any given expression type *) -type yes = private Yes -type no = | +type yes = Yes + +type no = + | No + (** Phantom types used in the definitions below. We don't make them + abstract, because the typer needs to know that their intersection is + empty. *) type desugared = < monomorphic : yes @@ -71,7 +76,8 @@ type desugared = ; explicitScopes : yes ; assertions : no ; defaultTerms : yes - ; exceptions : no > + ; exceptions : no + ; custom : no > type scopelang = < monomorphic : yes @@ -85,7 +91,8 @@ type scopelang = ; explicitScopes : yes ; assertions : no ; defaultTerms : yes - ; exceptions : no > + ; exceptions : no + ; custom : no > type dcalc = < monomorphic : yes @@ -99,7 +106,8 @@ type dcalc = ; explicitScopes : no ; assertions : yes ; defaultTerms : yes - ; exceptions : no > + ; exceptions : no + ; custom : no > type lcalc = < monomorphic : yes @@ -113,7 +121,8 @@ type lcalc = ; explicitScopes : no ; assertions : yes ; defaultTerms : no - ; exceptions : yes > + ; exceptions : yes + ; custom : no > type 'a any = < .. > as 'a (** ['a any] is 'a, but adds the constraint that it should be restricted to @@ -131,7 +140,8 @@ type ('a, 'b) dcalc_lcalc = ; explicitScopes : no ; assertions : yes ; defaultTerms : 'a - ; exceptions : 'b > + ; exceptions : 'b + ; custom : no > (** This type regroups Dcalc and Lcalc ASTs. *) (** {2 Types} *) @@ -379,6 +389,7 @@ and ('a, 'b, 'm) base_gexpr = -> ('a, (< .. > as 'b), 'm) base_gexpr | EArray : ('a, 'm) gexpr list -> ('a, < .. >, 'm) base_gexpr | EVar : ('a, 'm) naked_gexpr Bindlib.var -> ('a, _, 'm) base_gexpr + | EExternal : Qident.t -> ('a, < .. >, 't) base_gexpr | EAbs : { binder : (('a, 'a, 'm) base_gexpr, ('a, 'm) gexpr) Bindlib.mbinder; tys : typ list; @@ -456,6 +467,16 @@ and ('a, 'b, 'm) base_gexpr = handler : ('a, 'm) gexpr; } -> ('a, < exceptions : yes ; .. >, 'm) base_gexpr + (* Only used during evaluation *) + | ECustom : { + obj : Obj.t; + targs : typ list; + tret : typ; + } + -> ('a, < custom : yes ; .. >, 't) base_gexpr + (** A function of the given type, as a runtime OCaml object. The specified + types for arguments and result must be the Catala types corresponding + to the runtime types of the function. *) (** Useful for errors and printing, for example *) type any_expr = AnyExpr : ('a, _) gexpr -> any_expr @@ -552,6 +573,7 @@ type decl_ctx = { ctx_struct_fields : StructField.t StructName.Map.t Ident.Map.t; (** needed for disambiguation (desugared -> scope) *) ctx_scopes : scope_out_struct ScopeName.Map.t; + ctx_modules : typ Qident.Map.t; } type 'e program = { decl_ctx : decl_ctx; code_items : 'e code_item_list } diff --git a/compiler/shared_ast/expr.ml b/compiler/shared_ast/expr.ml index b24d5fdc..bd952936 100644 --- a/compiler/shared_ast/expr.ml +++ b/compiler/shared_ast/expr.ml @@ -109,6 +109,7 @@ let subst binder vars = Bindlib.msubst binder (Array.of_list (List.map Mark.remove vars)) let evar v mark = Mark.add mark (Bindlib.box_var v) +let eexternal eref mark = Mark.add mark (Bindlib.box (EExternal eref)) let etuple args = Box.appn args @@ fun args -> ETuple args let etupleaccess e index size = @@ -140,6 +141,9 @@ let eraise e1 = Box.app0 @@ ERaise e1 let ecatch body exn handler = Box.app2 body handler @@ fun body handler -> ECatch { body; exn; handler } +let ecustom obj targs tret mark = + Mark.add mark (Bindlib.box (ECustom { obj; targs; tret })) + let elocation loc = Box.app0 @@ ELocation loc let estruct name (fields : ('a, 't) boxed_gexpr StructField.Map.t) mark = @@ -268,6 +272,7 @@ let map | EOp { op; tys } -> eop op tys m | EArray args -> earray (List.map f args) m | EVar v -> evar (Var.translate v) m + | EExternal eref -> eexternal eref m | EAbs { binder; tys } -> let vars, body = Bindlib.unmbind binder in let body = f body in @@ -298,6 +303,7 @@ let map | EScopeCall { scope; args } -> let fields = ScopeVar.Map.map f args in escopecall scope fields m + | ECustom { obj; targs; tret } -> ecustom obj targs tret m let rec map_top_down ~f e = map ~f:(map_top_down ~f) (f e) let map_marks ~f e = map_top_down ~f:(Mark.map_mark f) e @@ -310,7 +316,9 @@ let shallow_fold (acc : 'acc) : 'acc = let lfold x acc = List.fold_left (fun acc x -> f x acc) acc x in match Mark.remove e with - | ELit _ | EOp _ | EVar _ | ERaise _ | ELocation _ | EEmptyError -> acc + | ELit _ | EOp _ | EVar _ | EExternal _ | ERaise _ | ELocation _ | EEmptyError + -> + acc | EApp { f = e; args } -> acc |> f e |> lfold args | EArray args -> acc |> lfold args | EAbs { binder; tys = _ } -> @@ -330,6 +338,7 @@ let shallow_fold | EMatch { e; cases; _ } -> acc |> f e |> EnumConstructor.Map.fold (fun _ -> f) cases | EScopeCall { args; _ } -> acc |> ScopeVar.Map.fold (fun _ -> f) args + | ECustom _ -> acc (* Like [map], but also allows to gather a result bottom-up. *) let map_gather @@ -360,6 +369,7 @@ let map_gather let acc, args = lfoldmap args in acc, earray args m | EVar v -> acc, evar (Var.translate v) m + | EExternal eref -> acc, eexternal eref m | EAbs { binder; tys } -> let vars, body = Bindlib.unmbind binder in let acc, body = f body in @@ -433,6 +443,7 @@ let map_gather args (acc, ScopeVar.Map.empty) in acc, escopecall scope args m + | ECustom { obj; targs; tret } -> acc, ecustom obj targs tret m (* - *) @@ -541,6 +552,7 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = fun e1 e2 -> match Mark.remove e1, Mark.remove e2 with | EVar v1, EVar v2 -> Bindlib.eq_vars v1 v2 + | EExternal eref1, EExternal eref2 -> Qident.equal eref1 eref2 | ETuple es1, ETuple es2 -> equal_list es1 es2 | ( ETupleAccess { e = e1; index = id1; size = s1 }, ETupleAccess { e = e2; index = id2; size = s2 } ) -> @@ -588,10 +600,14 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool = | ( EScopeCall { scope = s1; args = fields1 }, EScopeCall { scope = s2; args = fields2 } ) -> ScopeName.equal s1 s2 && ScopeVar.Map.equal equal fields1 fields2 - | ( ( EVar _ | ETuple _ | ETupleAccess _ | EArray _ | ELit _ | EAbs _ | EApp _ - | EAssert _ | EOp _ | EDefault _ | EIfThenElse _ | EEmptyError - | EErrorOnEmpty _ | ERaise _ | ECatch _ | ELocation _ | EStruct _ - | EDStructAccess _ | EStructAccess _ | EInj _ | EMatch _ | EScopeCall _ ), + | ( ECustom { obj = obj1; targs = targs1; tret = tret1 }, + ECustom { obj = obj2; targs = targs2; tret = tret2 } ) -> + Type.equal_list targs1 targs2 && Type.equal tret1 tret2 && obj1 == obj2 + | ( ( EVar _ | EExternal _ | ETuple _ | ETupleAccess _ | EArray _ | ELit _ + | EAbs _ | EApp _ | EAssert _ | EOp _ | EDefault _ | EIfThenElse _ + | EEmptyError | EErrorOnEmpty _ | ERaise _ | ECatch _ | ELocation _ + | EStruct _ | EDStructAccess _ | EStructAccess _ | EInj _ | EMatch _ + | EScopeCall _ | ECustom _ ), _ ) -> false @@ -614,6 +630,8 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = List.compare compare a1 a2 | EVar v1, EVar v2 -> Bindlib.compare_vars v1 v2 + | EExternal eref1, EExternal eref2 -> + Qident.compare eref1 eref2 | EAbs {binder=binder1; tys=typs1}, EAbs {binder=binder2; tys=typs2} -> List.compare Type.compare typs1 typs2 @@< fun () -> @@ -678,11 +696,15 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int = compare_except ex1 ex2 @@< fun () -> compare etry1 etry2 @@< fun () -> compare ewith1 ewith2 + | ECustom _, _ | _, ECustom _ -> + (* fixme: ideally this would be forbidden by typing *) + invalid_arg "Custom block comparison" | ELit _, _ -> -1 | _, ELit _ -> 1 | EApp _, _ -> -1 | _, EApp _ -> 1 | EOp _, _ -> -1 | _, EOp _ -> 1 | EArray _, _ -> -1 | _, EArray _ -> 1 | EVar _, _ -> -1 | _, EVar _ -> 1 + | EExternal _, _ -> -1 | _, EExternal _ -> 1 | EAbs _, _ -> -1 | _, EAbs _ -> 1 | EIfThenElse _, _ -> -1 | _, EIfThenElse _ -> 1 | ELocation _, _ -> -1 | _, ELocation _ -> 1 @@ -735,7 +757,7 @@ let format ppf e = Print.expr ~debug:false () ppf e let rec size : type a. (a, 't) gexpr -> int = fun e -> match Mark.remove e with - | EVar _ | ELit _ | EOp _ | EEmptyError -> 1 + | EVar _ | EExternal _ | ELit _ | EOp _ | EEmptyError | ECustom _ -> 1 | ETuple args -> List.fold_left (fun acc arg -> acc + size arg) 1 args | EArray args -> List.fold_left (fun acc arg -> acc + size arg) 1 args | ETupleAccess { e; _ } -> size e + 1 diff --git a/compiler/shared_ast/expr.mli b/compiler/shared_ast/expr.mli index dbc9e9c1..65363829 100644 --- a/compiler/shared_ast/expr.mli +++ b/compiler/shared_ast/expr.mli @@ -32,6 +32,7 @@ val rebox : ('a any, 'm) gexpr -> ('a, 'm) boxed_gexpr (** Rebuild the whole term, re-binding all variables and exposing free variables *) val evar : ('a, 'm) gexpr Var.t -> 'm mark -> ('a, 'm) boxed_gexpr +val eexternal : Qident.t -> 'm mark -> ('a any, 'm) boxed_gexpr val bind : ('a, 'm) gexpr Var.t array -> @@ -142,6 +143,13 @@ val escopecall : 'm mark -> ((< explicitScopes : yes ; .. > as 'a), 'm) boxed_gexpr +val ecustom : + Obj.t -> + Type.t list -> + Type.t -> + 'm mark -> + (< custom : Definitions.yes ; .. >, 'm) boxed_gexpr + val fun_id : 'm mark -> ('a any, 'm) boxed_gexpr (** {2 Manipulation of marks} *) diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index 68455cd8..91438278 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -23,6 +23,21 @@ open Definitions open Op module Runtime = Runtime_ocaml.Runtime +type features = + < monomorphic : yes + ; polymorphic : yes + ; overloaded : no + ; resolved : yes + ; syntacticNames : no + ; resolvedNames : yes + ; scopeVarStates : no + ; scopeVarSimpl : no + ; explicitScopes : no + ; assertions : yes > + +type ('d, 'e, 'c) astk = + < features ; defaultTerms : 'd ; exceptions : 'e ; custom : 'c > + (** {1 Helpers} *) let is_empty_error : type a. (a, 'm) gexpr -> bool = @@ -375,10 +390,226 @@ let rec evaluate_operator _ ) -> err () +(* /S\ dark magic here. This relies both on internals of [Lcalc.to_ocaml] *and* + of the OCaml runtime *) +let rec runtime_to_val : + (decl_ctx -> ('a, 'm) gexpr -> ('a, 'm) gexpr) -> + decl_ctx -> + 'm mark -> + typ -> + Obj.t -> + (((_, _, yes) astk as 'a), 'm) gexpr = + fun eval_expr ctx m ty o -> + let m = Expr.map_ty (fun _ -> ty) m in + match Mark.remove ty with + | TLit TBool -> ELit (LBool (Obj.obj o)), m + | TLit TUnit -> ELit LUnit, m + | TLit TInt -> ELit (LInt (Obj.obj o)), m + | TLit TRat -> ELit (LRat (Obj.obj o)), m + | TLit TMoney -> ELit (LMoney (Obj.obj o)), m + | TLit TDate -> ELit (LDate (Obj.obj o)), m + | TLit TDuration -> ELit (LDuration (Obj.obj o)), m + | TTuple ts -> + ( ETuple + (List.map2 + (runtime_to_val eval_expr ctx m) + ts + (Array.to_list (Obj.obj o))), + m ) + | TStruct name -> + StructName.Map.find name ctx.ctx_structs + |> StructField.Map.to_seq + |> Seq.map2 + (fun o (fld, ty) -> fld, runtime_to_val eval_expr ctx m ty o) + (Array.to_seq (Obj.obj o)) + |> StructField.Map.of_seq + |> fun fields -> EStruct { name; fields }, m + | TEnum name -> + (* we only use non-constant constructors of arity 1, which allows us to + always use the tag directly (ordered as declared in the constr map), and + the field 0 *) + let cons, ty = + List.nth + (EnumConstructor.Map.bindings (EnumName.Map.find name ctx.ctx_enums)) + (Obj.tag o - Obj.first_non_constant_constructor_tag) + in + let e = runtime_to_val eval_expr ctx m ty (Obj.field o 0) in + EInj { name; cons; e }, m + | TOption _ty -> assert false + | TArray ty -> + ( EArray + (List.map + (runtime_to_val eval_expr ctx m ty) + (Array.to_list (Obj.obj o))), + m ) + | TArrow (targs, tret) -> ECustom { obj = o; targs; tret }, m + | TAny -> assert false + +and val_to_runtime : + (decl_ctx -> ('a, 'm) gexpr -> ('a, 'm) gexpr) -> + decl_ctx -> + typ -> + ('b, 'm) gexpr -> + Obj.t = + fun eval_expr ctx ty v -> + match Mark.remove ty, Mark.remove v with + | TLit TBool, ELit (LBool b) -> Obj.repr b + | TLit TUnit, ELit LUnit -> Obj.repr () + | TLit TInt, ELit (LInt i) -> Obj.repr i + | TLit TRat, ELit (LRat r) -> Obj.repr r + | TLit TMoney, ELit (LMoney m) -> Obj.repr m + | TLit TDate, ELit (LDate t) -> Obj.repr t + | TLit TDuration, ELit (LDuration d) -> Obj.repr d + | TTuple ts, ETuple es -> + List.map2 (val_to_runtime eval_expr ctx) ts es |> Array.of_list |> Obj.repr + | TStruct name1, EStruct { name; fields } -> + assert (StructName.equal name name1); + let fld_tys = StructName.Map.find name ctx.ctx_structs in + Seq.map2 + (fun (_, ty) (_, v) -> val_to_runtime eval_expr ctx ty v) + (StructField.Map.to_seq fld_tys) + (StructField.Map.to_seq fields) + |> Array.of_seq + |> Obj.repr + | TEnum name1, EInj { name; cons; e } -> + assert (EnumName.equal name name1); + let rec find_tag n = function + | [] -> assert false + | (c, ty) :: _ when EnumConstructor.equal c cons -> n, ty + | _ :: r -> find_tag (n + 1) r + in + let tag, ty = + find_tag Obj.first_non_constant_constructor_tag + (EnumConstructor.Map.bindings (EnumName.Map.find name ctx.ctx_enums)) + in + let o = Obj.with_tag tag (Obj.repr (Some ())) in + Obj.set_field o 0 (val_to_runtime eval_expr ctx ty e); + o + | TOption _ty, _ -> assert false + | TArray ty, EArray es -> + Array.of_list (List.map (val_to_runtime eval_expr ctx ty) es) |> Obj.repr + | TArrow (targs, tret), _ -> + let m = Mark.get v in + (* we want stg like [fun args -> val_to_runtime (eval_expr ctx (EApp (v, + args)))] but in curried form *) + let rec curry acc = function + | [] -> + let args = List.rev acc in + val_to_runtime eval_expr ctx tret + (eval_expr ctx (EApp { f = v; args }, m)) + | targ :: targs -> + Obj.repr (fun x -> + curry (runtime_to_val eval_expr ctx m targ x :: acc) targs) + in + curry [] targs + | _ -> assert false + +(* let f e = (e : (< .. > as 'a, 't) gexpr :> (< custom : _; 'a; .. >, 't) gexpr ) + * + * let f (type a) ((e: (< custom: a; .. >, 't) naked_gexpr), t) = + * if false then ECustom { obj = Obj.repr (); targs = []; tret = (TLit TUnit, Pos.no_pos) }, t + * else e, t *) +(* let rec add_custom: (< .. > + * type a b . (a, b, 't) base_gexpr * 't -> (< custom: yes; .. >, 't) gexpr + * = function + * | ECustom _, _ as e -> Expr.box e + * | (ELit _ + * | EApp _ + * | EOp _ + * | EArray _ + * | EVar _ + * | EExternal _ + * | EAbs _ + * | EIfThenElse _ + * | ETuple _ + * | ETupleAccess _ + * | EInj _ + * | EAssert _ + * | EDefault _ + * | EEmptyError + * | EErrorOnEmpty _ + * | ECatch _ + * | ERaise _ + * | ELocation _ + * | EStruct _ + * | EDStructAccess _ + * | EStructAccess _ + * | EMatch _ + * | EScopeCall _), _ + * as e + * -> Expr.map ~f:add_custom e + * + * ;; + * fun e -> + * if false then + * Expr.box + * (ECustom { obj = Obj.repr (); targs = []; tret = (TLit TUnit, Pos.no_pos) }, + * Marked.get_mark e) + * else *) + +(* type ('a, 'b) has_custom = < custom: 'a; .. > as 'b + * + * let f (type b) (e: ((_, b) has_custom, 't) naked_gexpr) : ((yes, b) has_custom, 't) naked_gexpr = match e with + * | EEmptyError when false -> + * ECustom { obj = Obj.repr (); targs = []; tret = (TLit TUnit, Pos.no_pos) } + * | ECustom _ as e -> e + * | EOp _ as e -> e + * | ELocation _ as e -> e + * | ELit _ as e -> e + * | EApp _ as e -> e + * | EArray _ as e -> e + * | EVar _ as e -> e + * | EExternal _ as e -> e + * | EAbs _ as e -> e + * | EIfThenElse _ as e -> e + * | ETuple _ as e -> e + * | ETupleAccess _ as e -> e + * | EInj _ as e -> e + * | EAssert _ as e -> e + * | EDefault _ as e -> e + * | EEmptyError as e -> e + * | EErrorOnEmpty _ as e -> e + * | ECatch _ as e -> e + * | ERaise _ as e -> e + * | EStruct _ as e -> e + * | EDStructAccess _ as e -> e + * | EStructAccess _ as e -> e + * | EMatch _ as e -> e + * | EScopeCall _ as e -> e *) + +(* let rec add_custom: + * type c d e. + * (< features; defaultTerms: d; exceptions: e; custom : c >, 't) gexpr -> + * (< features; defaultTerms: d; exceptions: e; custom : yes >, 't) gexpr boxed + * = function + * (\* Obj.magic (Expr.box e) *\) + * (\* | EEmptyError, m when false -> + * * Expr.box (ECustom { obj = Obj.repr (); targs = []; tret = (TLit TUnit, Pos.no_pos) }, m) *\) + * | (EDefault _ | EEmptyError | EErrorOnEmpty _), _ as e -> + * Expr.map ~f:add_custom + * (e : (< features; defaultTerms: yes; exceptions: e; custom : c >, 't) gexpr) + * | (ECatch _ | ERaise _), _ as e -> Expr.map ~f:add_custom e + * | ECustom _, _ -> assert false + * | (EOp _ + * | ELocation _ + * | ELit _ + * | EApp _ + * | EArray _ + * | EVar _ + * | EExternal _ + * | EAbs _ + * | EIfThenElse _ + * | ETuple _ + * | ETupleAccess _ + * | EInj _ + * | EAssert _ + * | EStruct _ + * | EStructAccess _ + * | EMatch _), _ as e -> Expr.map ~f:add_custom e *) + let rec evaluate_expr : - type a b. - decl_ctx -> ((a, b) dcalc_lcalc, 'm) gexpr -> ((a, b) dcalc_lcalc, 'm) gexpr - = + type d e. + decl_ctx -> ((d, e, yes) astk, 't) gexpr -> ((d, e, yes) astk, 't) gexpr = fun ctx e -> let m = Mark.get e in let pos = Expr.mark_pos m in @@ -387,6 +618,14 @@ let rec evaluate_expr : Message.raise_spanned_error pos "free variable found at evaluation (should not happen if term was \ well-typed)" + | EExternal qid -> ( + match Qident.Map.find_opt qid ctx.ctx_modules with + | None -> + Message.raise_spanned_error pos "Reference to %a could not be resolved" + Qident.format qid + | Some ty -> + let o = Runtime.lookup_value qid in + runtime_to_val evaluate_expr ctx m ty o) | EApp { f = e1; args } -> ( let e1 = evaluate_expr ctx e1 in let args = List.map (evaluate_expr ctx) args in @@ -403,11 +642,23 @@ let rec evaluate_expr : (Bindlib.mbinder_arity binder) (List.length args) | EOp { op; _ } -> evaluate_operator (evaluate_expr ctx) op m args + | ECustom { obj; targs; tret } -> + (* Applies the arguments one by one to the curried form *) + List.fold_left2 + (fun fobj targ arg -> + (Obj.obj fobj : Obj.t -> Obj.t) + (val_to_runtime evaluate_expr ctx targ arg)) + obj targs args + |> Obj.obj + |> fun o -> runtime_to_val evaluate_expr ctx m tret o | _ -> Message.raise_spanned_error pos "function has not been reduced to a lambda at evaluation (should not \ happen if the term was well-typed") - | (EAbs _ | ELit _ | EOp _) as e -> Mark.add m e (* these are values *) + | EAbs { binder; tys } -> Expr.unbox (Expr.eabs (Bindlib.box binder) tys m) + | ELit _ as e -> Mark.add m e + | EOp { op; tys } -> Expr.unbox (Expr.eop (Operator.translate op) tys m) + (* | EAbs _ as e -> Marked.mark m e (* these are values *) *) | EStruct { fields = es; name } -> let fields, es = List.split (StructField.Map.bindings es) in let es = List.map (evaluate_expr ctx) es in @@ -514,6 +765,7 @@ let rec evaluate_expr : Message.raise_spanned_error (Expr.pos e') "Expected a boolean literal for the result of this assertion \ (should not happen if the term was well-typed)") + | ECustom _ -> e | EEmptyError -> Mark.copy e EEmptyError | EErrorOnEmpty e' -> ( match evaluate_expr ctx e' with @@ -552,6 +804,142 @@ let rec evaluate_expr : evaluate_expr ctx handler) | _ -> . +(* type ('kind,'a,'b,'c,'d,'e,'f,'g,'h,'i,'k,'l,'m) astrec = { + * monomorphic : 'a + * ; polymorphic : 'b + * ; overloaded : 'c + * ; resolved : 'd + * ; syntacticNames : 'e + * ; resolvedNames : 'f + * ; scopeVarStates : 'g + * ; scopeVarSimpl : 'h + * ; explicitScopes : 'i + * ; assertions : 'j + * ; defaultTerms : 'k + * ; exceptions : 'l + * ; custom : 'm + * } + * constraint 'kind = < + * monomorphic : 'a + * ; polymorphic : 'b + * ; overloaded : 'c + * ; resolved : 'd + * ; syntacticNames : 'e + * ; resolvedNames : 'f + * ; scopeVarStates : 'g + * ; scopeVarSimpl : 'h + * ; explicitScopes : 'i + * ; assertions : 'j + * ; defaultTerms : 'k + * ; exceptions : 'l + * ; custom : 'm > + * + * type ('kind,'a,'b,'c,'d,'e,'f,'g,'h,'i,'k,'l,'m) astrec2 = + * Astrec: + * { monomorphic : 'a + * ; polymorphic : 'b + * ; overloaded : 'c + * ; resolved : 'd + * ; syntacticNames : 'e + * ; resolvedNames : 'f + * ; scopeVarStates : 'g + * ; scopeVarSimpl : 'h + * ; explicitScopes : 'i + * ; assertions : 'j + * ; defaultTerms : 'k + * ; exceptions : 'l + * ; custom : 'm } + * -> + * (< monomorphic : 'a + * ; polymorphic : 'b + * ; overloaded : 'c + * ; resolved : 'd + * ; syntacticNames : 'e + * ; resolvedNames : 'f + * ; scopeVarStates : 'g + * ; scopeVarSimpl : 'h + * ; explicitScopes : 'i + * ; assertions : 'j + * ; defaultTerms : 'k + * ; exceptions : 'l + * ; custom : 'm >, + * 'a,'b,'c,'d,'e,'f,'g,'h,'i,'k,'l,'m) astrec2 + * + * let customise + * (type x a b c d e f g h i j k l m) + * (ty: ('kind,a,b,c,d,e,f,g,h,i,k,l,m) astrec2) + * (e: (x, 't) gexpr) + * : ('kind, 't) gexpr = + * match ty, e with + * | Astrec { custom = Yes; _ }, (ECustom _, _ as e) -> (e: ('kind, 't) gexpr) + * | Astrec { custom = No; _ }, (ECustom _, _) -> invalid_arg "Bad AST cast" + * | EOp {op;tys}, m -> Expr.eop (Operator.translate op) tys m + * | EDefault _, _ as e -> Expr.map ~f e + * | EEmptyError, _ as e -> Expr.map ~f e + * | EErrorOnEmpty _, _ as e -> Expr.map ~f e + * | ECatch _, _ as e -> Expr.map ~f e + * | ERaise _, _ as e -> Expr.map ~f e + * | (EAssert _ + * | ELit _ + * | EApp _ + * | EArray _ + * | EVar _ + * | EExternal _ + * | EAbs _ + * | EIfThenElse _ + * | ETuple _ + * | ETupleAccess _ + * | EInj _ + * | EStruct _ + * | EStructAccess _ + * | EMatch _), _ as e -> Expr.map ~f e + * | _ -> . *) + +(* Typing shenanigan to add custom terms to the AST type. This is an identity + and could be optimised into [Obj.magic]. *) +let addcustom e = + let rec f : + type c d e. + ((d, e, c) astk, 't) gexpr -> ((d, e, yes) astk, 't) gexpr boxed = + function + | (ECustom _, _) as e -> Expr.map ~f e + | EOp { op; tys }, m -> Expr.eop (Operator.translate op) tys m + | (EDefault _, _) as e -> Expr.map ~f e + | (EEmptyError, _) as e -> Expr.map ~f e + | (EErrorOnEmpty _, _) as e -> Expr.map ~f e + | (ECatch _, _) as e -> Expr.map ~f e + | (ERaise _, _) as e -> Expr.map ~f e + | ( ( EAssert _ | ELit _ | EApp _ | EArray _ | EVar _ | EExternal _ | EAbs _ + | EIfThenElse _ | ETuple _ | ETupleAccess _ | EInj _ | EStruct _ + | EStructAccess _ | EMatch _ ), + _ ) as e -> + Expr.map ~f e + | _ -> . + in + Expr.unbox (f e) + +let delcustom e = + let rec f : + type c d e. + ((d, e, c) astk, 't) gexpr -> ((d, e, no) astk, 't) gexpr boxed = function + | ECustom _, _ -> invalid_arg "Custom term remaining in evaluated term" + | EOp { op; tys }, m -> Expr.eop (Operator.translate op) tys m + | (EDefault _, _) as e -> Expr.map ~f e + | (EEmptyError, _) as e -> Expr.map ~f e + | (EErrorOnEmpty _, _) as e -> Expr.map ~f e + | (ECatch _, _) as e -> Expr.map ~f e + | (ERaise _, _) as e -> Expr.map ~f e + | ( ( EAssert _ | ELit _ | EApp _ | EArray _ | EVar _ | EExternal _ | EAbs _ + | EIfThenElse _ | ETuple _ | ETupleAccess _ | EInj _ | EStruct _ + | EStructAccess _ | EMatch _ ), + _ ) as e -> + Expr.map ~f e + | _ -> . + in + Expr.unbox (f e) + +let evaluate_expr ctx e = delcustom (evaluate_expr ctx (addcustom e)) + let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list = let e = Expr.unbox @@ Program.to_expr p s in @@ -601,11 +989,24 @@ let interpret_program_lcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list "The interpreter can only interpret terms starting with functions having \ thunked arguments" +let dynload_modules () = + (* FIXME: WIP placeholder ; also, each file should be loaded only once *) + match Sys.getenv_opt "CATALA_INTF" with + | None | Some "" -> () + | Some str -> + let files = String.split_on_char ',' str in + List.iter + (fun f -> + let mlf = Filename.remove_extension f ^ ".cmxs" in + Dynlink.loadfile mlf) + files + (** {1 API} *) let interpret_program_dcalc p s : (Uid.MarkedString.info * ('a, 'm) gexpr) list = let ctx = p.decl_ctx in let e = Expr.unbox (Program.to_expr p s) in + dynload_modules (); match evaluate_expr p.decl_ctx e with | (EAbs { tys = [((TStruct s_in, _) as _targs)]; _ }, mark_e) as e -> begin (* At this point, the interpreter seeks to execute the scope but does not diff --git a/compiler/shared_ast/interpreter.mli b/compiler/shared_ast/interpreter.mli index a59210e5..778cdd88 100644 --- a/compiler/shared_ast/interpreter.mli +++ b/compiler/shared_ast/interpreter.mli @@ -20,8 +20,21 @@ open Catala_utils open Definitions +type features = + < monomorphic : yes + ; polymorphic : yes + ; overloaded : no + ; resolved : yes + ; syntacticNames : no + ; resolvedNames : yes + ; scopeVarStates : no + ; scopeVarSimpl : no + ; explicitScopes : no + ; assertions : yes > +(** The interpreter only works on dcalc and lcalc, which share these features *) + val evaluate_operator : - ((((_, _) dcalc_lcalc as 'a), 'm) gexpr -> ('a, 'm) gexpr) -> + (((< features ; .. > as 'a), 'm) gexpr -> ('a, 'm) gexpr) -> 'a operator -> 'm mark -> ('a, 'm) gexpr list -> @@ -32,9 +45,7 @@ val evaluate_operator : operator. *) val evaluate_expr : - decl_ctx -> - (('a, 'b) dcalc_lcalc, 'm) gexpr -> - (('a, 'b) dcalc_lcalc, 'm) gexpr + decl_ctx -> (((_, _) dcalc_lcalc as 'a), 'm) gexpr -> ('a, 'm) gexpr (** Evaluates an expression according to the semantics of the default calculus. *) val interpret_program_dcalc : diff --git a/compiler/shared_ast/optimizations.ml b/compiler/shared_ast/optimizations.ml index 263d0179..8d298add 100644 --- a/compiler/shared_ast/optimizations.ml +++ b/compiler/shared_ast/optimizations.ml @@ -184,7 +184,7 @@ let rec optimize_expr : when name = name1 -> Mark.remove (StructField.Map.find field fields) | EDefault { excepts; just; cons } -> ( - (* TODO: mechanically prove each of these optimizations correct :) *) + (* TODO: mechanically prove each of these optimizations correct *) let excepts = List.filter (fun except -> Mark.remove except <> EEmptyError) excepts (* we can discard the exceptions that are always empty error *) @@ -198,7 +198,8 @@ let rec optimize_expr : (* at this point we know a conflict error will be triggered so we just feed the expression to the interpreter that will print the beautiful right error message *) - Mark.remove (Interpreter.evaluate_expr ctx.decl_ctx e) + let _ = Interpreter.evaluate_expr ctx.decl_ctx e in + assert false else match excepts, just with | [except], _ when Expr.is_value except -> @@ -302,7 +303,12 @@ let rec optimize_expr : in Expr.Box.app1 e reduce mark -let optimize_expr (decl_ctx : decl_ctx) (e : (('a, 'b) dcalc_lcalc, 'm) gexpr) = +let optimize_expr : + 'm. + decl_ctx -> + (('a, 'b) dcalc_lcalc, 'm) gexpr -> + (('a, 'b) dcalc_lcalc, 'm) boxed_gexpr = + fun (decl_ctx : decl_ctx) (e : (('a, 'b) dcalc_lcalc, 'm) gexpr) -> optimize_expr { var_values = Var.Map.empty; decl_ctx } e let optimize_program (p : 'm program) : 'm program = @@ -339,15 +345,7 @@ let test_iota_reduction_1 () = x" (Format.asprintf "before=%a\nafter=%a" Expr.format (Expr.unbox matchA) Expr.format - (Expr.unbox - (optimize_expr - { - ctx_enums = EnumName.Map.empty; - ctx_structs = StructName.Map.empty; - ctx_struct_fields = Ident.Map.empty; - ctx_scopes = ScopeName.Map.empty; - } - (Expr.unbox matchA)))) + (Expr.unbox (optimize_expr Program.empty_ctx (Expr.unbox matchA)))) let cases_of_list l : ('a, 't) boxed_gexpr EnumConstructor.Map.t = EnumConstructor.Map.of_seq @@ -409,12 +407,4 @@ let test_iota_reduction_2 () = \ | B → (λ (x: any) → D B x)\n" (Format.asprintf "before=@[%a@]@.after=%a@." Expr.format (Expr.unbox matchA) Expr.format - (Expr.unbox - (optimize_expr - { - ctx_enums = EnumName.Map.empty; - ctx_structs = StructName.Map.empty; - ctx_struct_fields = Ident.Map.empty; - ctx_scopes = ScopeName.Map.empty; - } - (Expr.unbox matchA)))) + (Expr.unbox (optimize_expr Program.empty_ctx (Expr.unbox matchA)))) diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index 1c2560d5..64c32589 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -379,6 +379,7 @@ module Precedence = struct | EOp _ -> Contained | EArray _ -> Contained | EVar _ -> Contained + | EExternal _ -> Contained | EAbs _ -> Abs | EIfThenElse _ -> Contained | EStruct _ -> Contained @@ -395,6 +396,7 @@ module Precedence = struct | EErrorOnEmpty _ -> App | ERaise _ -> App | ECatch _ -> App + | ECustom _ -> Contained let needs_parens ~context ?(rhs = false) e = match expr context, expr e with @@ -461,6 +463,7 @@ let rec expr_aux : let rhs ex = paren ~rhs:true ex in match Mark.remove e with | EVar v -> var fmt v + | EExternal eref -> Qident.format fmt eref | ETuple es -> Format.fprintf fmt "@[%a%a%a@]" punctuation "(" (Format.pp_print_list @@ -665,6 +668,7 @@ let rec expr_aux : Format.pp_close_box fmt (); punctuation fmt "}"; Format.pp_close_box fmt () + | ECustom _ -> Format.pp_print_string fmt "" let rec colors = let open Ocolor_types in diff --git a/compiler/shared_ast/program.ml b/compiler/shared_ast/program.ml index c14e3b07..c05f6798 100644 --- a/compiler/shared_ast/program.ml +++ b/compiler/shared_ast/program.ml @@ -28,6 +28,15 @@ let fold_left_exprs ~f ~init { code_items; decl_ctx = _ } = let fold_right_exprs ~f ~init { code_items; decl_ctx = _ } = Scope.fold_right ~f:(fun e _ acc -> f e acc) ~init code_items +let empty_ctx = + { + ctx_enums = EnumName.Map.empty; + ctx_structs = StructName.Map.empty; + ctx_struct_fields = Ident.Map.empty; + ctx_scopes = ScopeName.Map.empty; + ctx_modules = Qident.Map.empty; + } + let get_scope_body { code_items; _ } scope = match Scope.fold_left ~init:None diff --git a/compiler/shared_ast/program.mli b/compiler/shared_ast/program.mli index 5f636fdb..d1c8b704 100644 --- a/compiler/shared_ast/program.mli +++ b/compiler/shared_ast/program.mli @@ -17,6 +17,10 @@ open Definitions +(** {2 Program declaration context helpers} *) + +val empty_ctx : decl_ctx + (** {2 Transformations} *) val map_exprs : diff --git a/compiler/shared_ast/qident.ml b/compiler/shared_ast/qident.ml new file mode 100644 index 00000000..cebaa018 --- /dev/null +++ b/compiler/shared_ast/qident.ml @@ -0,0 +1,53 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2023 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +(** This module defines module names and path accesses, used to refer to + separate compilation units *) + +open Catala_utils + +type modname = string +type ident = string +type path = modname list +type t = path * ident + +let compare_path = List.compare String.compare +let equal_path = List.equal String.equal + +let compare (p1, i1) (p2, i2) = + match compare_path p1 p2 with 0 -> String.compare i1 i2 | n -> n + +let equal (p1, i1) (p2, i2) = equal_path p1 p2 && String.equal i1 i2 + +module Ord = struct + type nonrec t = t + + let compare = compare +end + +module Set = Set.Make (Ord) +module Map = Map.Make (Ord) + +let format_modname ppf m = Format.fprintf ppf "@{%s@}" m + +let format_path ppf p = + let pp_sep ppf () = Format.fprintf ppf "@{.@}" in + Format.pp_print_list ~pp_sep format_modname ppf p; + pp_sep ppf () + +let format ppf (p, i) = + format_path ppf p; + Format.pp_print_string ppf i diff --git a/compiler/shared_ast/qident.mli b/compiler/shared_ast/qident.mli new file mode 100644 index 00000000..62450319 --- /dev/null +++ b/compiler/shared_ast/qident.mli @@ -0,0 +1,36 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2023 Inria, contributor: + Louis Gesbert + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +(** This module defines module names and path accesses, used to refer to + separate compilation units. *) + +type modname = string +(** Expected to be a uident *) + +type ident = string +(** Expected to be a lident *) + +type path = modname list +type t = path * ident + +val compare_path : path -> path -> int +val equal_path : path -> path -> bool +val compare : t -> t -> int +val equal : t -> t -> bool +val format : Format.formatter -> t -> unit + +module Set : Set.S with type elt = t +module Map : Map.S with type key = t diff --git a/compiler/shared_ast/shared_ast.ml b/compiler/shared_ast/shared_ast.ml index 4d825107..692fd1eb 100644 --- a/compiler/shared_ast/shared_ast.ml +++ b/compiler/shared_ast/shared_ast.ml @@ -16,6 +16,7 @@ include Definitions module Var = Var +module Qident = Qident module Type = Type module Operator = Operator module Expr = Expr diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 46abbb34..cb09bd5e 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -610,6 +610,16 @@ and typecheck_expr_top_down : "Variable %s not found in the current context" (Bindlib.name_of v) in Expr.evar (Var.translate v) (mark_with_tau_and_unify tau') + | A.EExternal eref -> + let ty = + try Qident.Map.find eref ctx.ctx_modules + with Not_found -> + Message.raise_spanned_error pos_e + "Could not resolve the reference to %a.@ Make sure the corresponding \ + module was properly loaded?" + Qident.format eref + in + Expr.eexternal eref (mark_with_tau_and_unify (ast_to_typ ty)) | A.ELit lit -> Expr.elit lit (ty_mark (lit_type lit)) | A.ETuple es -> let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in @@ -787,6 +797,11 @@ and typecheck_expr_top_down : List.map (typecheck_expr_top_down ~leave_unresolved ctx env cell_type) es in Expr.earray es' mark + | A.ECustom { obj; targs; tret } -> + let mark = + mark_with_tau_and_unify (ast_to_typ (A.TArrow (targs, tret), Expr.pos e)) + in + Expr.ecustom obj targs tret mark let wrap ctx f e = try f e diff --git a/compiler/surface/ast.ml b/compiler/surface/ast.ml index 9c22a447..53e38443 100644 --- a/compiler/surface/ast.ml +++ b/compiler/surface/ast.ml @@ -869,6 +869,8 @@ type law_structure = }] type program = { + program_interfaces : + ((Shared_ast.Qident.path[@opaque]) * code_item Mark.pos) list; program_items : law_structure list; program_source_files : (string[@opaque]) list; } diff --git a/compiler/surface/parser_driver.ml b/compiler/surface/parser_driver.ml index 39804cd1..fa5843e0 100644 --- a/compiler/surface/parser_driver.ml +++ b/compiler/surface/parser_driver.ml @@ -291,6 +291,7 @@ let rec parse_source_file (match input with Some input -> close_in input | None -> ()); let program = expand_includes source_file_name commands language in { + program_interfaces = []; program_items = program.Ast.program_items; program_source_files = source_file_name :: program.Ast.program_source_files; } @@ -309,6 +310,7 @@ and expand_includes let sub_source = Filename.concat source_dir (Mark.remove sub_source) in let includ_program = parse_source_file (FileName sub_source) language in { + program_interfaces = []; Ast.program_source_files = acc.Ast.program_source_files @ includ_program.program_source_files; Ast.program_items = @@ -316,22 +318,58 @@ and expand_includes } | Ast.LawHeading (heading, commands') -> let { + Ast.program_interfaces = _; Ast.program_items = commands'; Ast.program_source_files = new_sources; } = expand_includes source_file commands' language in { + Ast.program_interfaces = []; Ast.program_source_files = acc.Ast.program_source_files @ new_sources; Ast.program_items = acc.Ast.program_items @ [Ast.LawHeading (heading, commands')]; } | i -> { acc with Ast.program_items = acc.Ast.program_items @ [i] }) - { Ast.program_source_files = []; Ast.program_items = [] } + { + Ast.program_interfaces = []; + Ast.program_source_files = []; + Ast.program_items = []; + } commands (** {1 API} *) +let ext_id = "__external__" + +let add_interface source_file language path program = + let rec filter acc = function + | Ast.LawInclude _ -> acc + | Ast.LawHeading (_, str) -> List.fold_left filter acc str + | Ast.LawText _ -> acc + | Ast.CodeBlock (code, _, true) -> + List.fold_left + (fun acc -> function + | Ast.ScopeUse _, _ -> acc + | ((Ast.ScopeDecl _ | StructDecl _ | EnumDecl _), _) as e -> + (path, e) :: acc + | Ast.Topdef def, m -> + ( path, + ( Ast.Topdef + { def with topdef_expr = Ast.Ident ([], (ext_id, m)), m }, + m ) ) + :: acc) + acc code + | Ast.CodeBlock (_, _, false) -> + (* Non-metadata blocks are ignored *) + acc + in + let program_interfaces = + List.fold_left filter program.Ast.program_interfaces + (parse_source_file source_file language).Ast.program_items + in + { program with Ast.program_interfaces } + let parse_top_level_file (source_file : Pos.input_file) (language : Cli.backend_lang) : Ast.program = diff --git a/compiler/surface/parser_driver.mli b/compiler/surface/parser_driver.mli index d889ba12..e0abf5f8 100644 --- a/compiler/surface/parser_driver.mli +++ b/compiler/surface/parser_driver.mli @@ -19,4 +19,13 @@ open Catala_utils +val add_interface : + Pos.input_file -> + Cli.backend_lang -> + Shared_ast.Qident.path -> + Ast.program -> + Ast.program +(** Reads only declarations in metadata in the supplied input file, and add them + to the given program *) + val parse_top_level_file : Pos.input_file -> Cli.backend_lang -> Ast.program diff --git a/compiler/verification/z3backend.real.ml b/compiler/verification/z3backend.real.ml index d21623bf..ce4ce806 100644 --- a/compiler/verification/z3backend.real.ml +++ b/compiler/verification/z3backend.real.ml @@ -656,6 +656,7 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr = of a match. It actually corresponds to applying an accessor to an enum, the corresponding Z3 expression was previously stored in the context *) ctx, e) + | EExternal _ -> failwith "[Z3 encoding] EExternal unsupported" | EStruct _ -> failwith "[Z3 encoding] EStruct unsupported" | EStructAccess { e; field; name } -> let ctx, z3_struct = find_or_create_struct ctx name in diff --git a/runtimes/ocaml/runtime.ml b/runtimes/ocaml/runtime.ml index def5889e..7f1923c0 100644 --- a/runtimes/ocaml/runtime.ml +++ b/runtimes/ocaml/runtime.ml @@ -737,3 +737,23 @@ module Oper = struct end include Oper + +type hash = string + +let modules_table : (string, hash) Hashtbl.t = Hashtbl.create 13 +let values_table : (string list * string, Obj.t) Hashtbl.t = Hashtbl.create 13 + +let register_module modname values hash = + Hashtbl.add modules_table modname hash; + List.iter (fun (id, v) -> Hashtbl.add values_table ([modname], id) v) values + +let check_module m h = String.equal (Hashtbl.find modules_table m) h + +let lookup_value qid = + try Hashtbl.find values_table qid + with Not_found -> + failwith + ("Could not resolve reference to " + ^ String.concat "." (fst qid) + ^ "." + ^ snd qid) diff --git a/runtimes/ocaml/runtime.mli b/runtimes/ocaml/runtime.mli index 1aec4e8b..7acb564b 100644 --- a/runtimes/ocaml/runtime.mli +++ b/runtimes/ocaml/runtime.mli @@ -385,3 +385,21 @@ module Oper : sig end include module type of Oper + +(** Modules API *) + +type hash = string + +val register_module : string -> (string * Obj.t) list -> hash -> unit +(** Registers a module by the given name defining the given bindings. Required + for evaluation to be able to access the given values. The last argument is + expected to be a hash of the source file and the Catala version, and will in + time be used to ensure that the module and the interface are in sync *) + +val check_module : string -> hash -> bool +(** Returns [true] if it has been registered with the correct hash, [false] if + there is a hash mismatch. + + @raise Not_found if the module does not exist at all *) + +val lookup_value : string list * string -> Obj.t diff --git a/tests/test_scope/good/191_fix_record_name_confusion.catala_en b/tests/test_scope/good/191_fix_record_name_confusion.catala_en index e3cc4b66..5f364ed6 100644 --- a/tests/test_scope/good/191_fix_record_name_confusion.catala_en +++ b/tests/test_scope/good/191_fix_record_name_confusion.catala_en @@ -17,12 +17,14 @@ scope ScopeB: ```catala-test-inline $ catala OCaml -O + (** This file has been generated by the Catala compiler, do not edit! *) open Runtime_ocaml.Runtime [@@@ocaml.warning "-4-26-27-32-41-42"] + module ScopeA = struct type t = {a: bool} end @@ -58,4 +60,9 @@ let scope_b (scope_b_in: ScopeBIn.t) : ScopeB.t = start_line=8; start_column=10; end_line=8; end_column=11; law_headings=["Article"]})) in {ScopeB.a = a_} +let () = + Runtime_ocaml.Runtime.register_module "191_fix_record_name_confusion" + [ "ScopeA", Obj.repr scope_a; + "ScopeB", Obj.repr scope_b ] + "todo-module-hash" ```