diff --git a/compiler/driver.ml b/compiler/driver.ml index 3459b16c..a2260256 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -293,7 +293,11 @@ let driver source_file (options : Cli.options) : int = let output_file, _ = get_output_format ~ext:p.Plugin.extension () in Cli.debug_print "Compiling program through backend \"%s\"..." p.Plugin.name; - p.Plugin.apply ~source_file ~output_file ~scope:options.ex_scope + p.Plugin.apply ~source_file ~output_file + ~scope: + (match options.ex_scope with + | None -> None + | Some _ -> Some scope_uid) (Shared_ast.Program.untype prgm) type_ordering | (`OCaml | `Interpret_Lcalc | `Python | `Lcalc | `Scalc | `Plugin _) @@ -389,7 +393,11 @@ let driver source_file (options : Cli.options) : int = in Cli.debug_print "Compiling program through backend \"%s\"..." p.Plugin.name; - p.Plugin.apply ~source_file ~output_file ~scope:options.ex_scope + p.Plugin.apply ~source_file ~output_file + ~scope: + (match options.ex_scope with + | None -> None + | Some _ -> Some scope_uid) prgm type_ordering | (`Python | `Scalc | `Plugin (Plugin.Scalc _)) as backend -> ( let prgm = Scalc.From_lcalc.translate_program prgm in @@ -427,7 +435,11 @@ let driver source_file (options : Cli.options) : int = Cli.debug_print "Writing to %s..." (Option.value ~default:"stdout" output_file); p.Plugin.apply ~source_file ~output_file - ~scope:options.ex_scope prgm type_ordering))))))); + ~scope: + (match options.ex_scope with + | None -> None + | Some _ -> Some scope_uid) + prgm type_ordering))))))); 0 with | Errors.StructuredError (msg, pos) -> diff --git a/compiler/plugin.ml b/compiler/plugin.ml index 83a149a4..8c5540ac 100644 --- a/compiler/plugin.ml +++ b/compiler/plugin.ml @@ -19,7 +19,7 @@ open Catala_utils type 'ast plugin_apply_fun_typ = source_file:Pos.input_file -> output_file:string option -> - scope:string option -> + scope:Shared_ast.ScopeName.t option -> 'ast -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/plugin.mli b/compiler/plugin.mli index 0d69561a..c8046201 100644 --- a/compiler/plugin.mli +++ b/compiler/plugin.mli @@ -21,7 +21,7 @@ open Catala_utils type 'ast plugin_apply_fun_typ = source_file:Pos.input_file -> output_file:string option -> - scope:string option -> + scope:Shared_ast.ScopeName.t option -> 'ast -> Scopelang.Dependency.TVertex.t list -> unit diff --git a/compiler/plugins/dune b/compiler/plugins/dune index 531a170f..0d768020 100644 --- a/compiler/plugins/dune +++ b/compiler/plugins/dune @@ -20,6 +20,14 @@ (modules json_schema) (libraries catala.driver)) +(library + (name lazy_interpreter) + (public_name catala.plugins.lazy-interpreter) + (synopsis + "Catala plugin that implements a different, experimental interpreter, featuring lazy and partial evaluation") + (modules lazy_interp) + (libraries shared_ast catala.driver)) + (documentation (package catala) (mld_files plugins)) diff --git a/compiler/plugins/json_schema.ml b/compiler/plugins/json_schema.ml index e0f8ce02..bd7b783f 100644 --- a/compiler/plugins/json_schema.ml +++ b/compiler/plugins/json_schema.ml @@ -22,7 +22,6 @@ let extension = "_schema.json" open Catala_utils open Shared_ast -open Lcalc.Ast open Lcalc.To_ocaml module D = Dcalc.Ast @@ -47,17 +46,6 @@ module To_json = struct in Format.fprintf fmt "%s" s - let rec find_scope_def (target_name : string) : - 'm expr code_item_list -> (ScopeName.t * 'm expr scope_body) option = - function - | Nil -> None - | Cons (ScopeDef (name, body), _) - when String.equal target_name (Marked.unmark (ScopeName.get_info name)) -> - Some (name, body) - | Cons (_, next_bind) -> - let _, next_scope = Bindlib.unbind next_bind in - find_scope_def target_name next_scope - let fmt_tlit fmt (tlit : typ_lit) = match tlit with | TUnit -> Format.fprintf fmt "\"type\": \"null\",@\n\"default\": null" @@ -203,31 +191,29 @@ module To_json = struct let format_program (fmt : Format.formatter) - (scope : string) + (scope : ScopeName.t) (prgm : 'm Lcalc.Ast.program) = - match find_scope_def scope prgm.code_items with - | None -> Cli.error_print "Internal error: scope '%s' not found." scope - | Some scope_def -> - Cli.call_unstyled (fun _ -> - Format.fprintf fmt - "{@[@\n\ - \"type\": \"object\",@\n\ - \"@[definitions\": {%a@]@\n\ - },@\n\ - \"@[properties\": {@\n\ - %a@]@\n\ - }@]@\n\ - }" - (fmt_definitions prgm.decl_ctx) - scope_def - (fmt_struct_properties prgm.decl_ctx) - (snd scope_def).scope_body_input_struct) + let scope_body = Program.get_scope_body prgm scope in + Cli.call_unstyled (fun _ -> + Format.fprintf fmt + "{@[@\n\ + \"type\": \"object\",@\n\ + \"@[definitions\": {%a@]@\n\ + },@\n\ + \"@[properties\": {@\n\ + %a@]@\n\ + }@]@\n\ + }" + (fmt_definitions prgm.decl_ctx) + (scope, scope_body) + (fmt_struct_properties prgm.decl_ctx) + scope_body.scope_body_input_struct) end let apply ~(source_file : Pos.input_file) ~(output_file : string option) - ~(scope : string option) + ~(scope : Shared_ast.ScopeName.t option) (prgm : 'm Lcalc.Ast.program) (type_ordering : Scopelang.Dependency.TVertex.t list) = ignore source_file; @@ -236,9 +222,9 @@ let apply | Some s -> File.with_formatter_of_opt_file output_file (fun fmt -> Cli.debug_print - "Writing JSON schema corresponding to the scope '%s' to the file \ + "Writing JSON schema corresponding to the scope '%a' to the file \ %s..." - s + ScopeName.format_t s (Option.value ~default:"stdout" output_file); To_json.format_program fmt s prgm) | None -> Cli.error_print "A scope must be specified for the plugin: %s" name diff --git a/compiler/plugins/lazy_interp.ml b/compiler/plugins/lazy_interp.ml new file mode 100644 index 00000000..1d854242 --- /dev/null +++ b/compiler/plugins/lazy_interp.ml @@ -0,0 +1,273 @@ +(* This file is part of the Catala compiler, a specification language for tax + and social benefits computation rules. Copyright (C) 2023 Inria, contributor: + Louis Gesbert . + + Licensed under the Apache License, Version 2.0 (the "License"); you may not + use this file except in compliance with the License. You may obtain a copy of + the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + License for the specific language governing permissions and limitations under + the License. *) + +open Catala_utils +open Shared_ast + +(* -- Definition of the lazy interpreter -- *) + +let log fmt = Format.ifprintf Format.err_formatter (fmt ^^ "@\n") +let error e = Errors.raise_spanned_error (Expr.pos e) +let noassert = true + +type laziness_level = { + eval_struct : bool; + (* if true, evaluate members of structures, tuples, etc. *) + eval_op : bool; + (* if false, evaluate the operands but keep e.g. `3 + 4` as is *) + eval_default : bool; + (* if false, stop evaluating as soon as you can discriminate with + `EEmptyError` *) +} + +let value_level = { eval_struct = false; eval_op = true; eval_default = true } + +module Env = struct + type 'm t = + | Env of + ((dcalc, 'm mark) gexpr, ((dcalc, 'm mark) gexpr * 'm t) ref) Var.Map.t + + let find v (Env t) = Var.Map.find v t + let add v e e_env (Env t) = Env (Var.Map.add v (ref (e, e_env)) t) + let empty = Env Var.Map.empty + + let join (Env t1) (Env t2) = + Env + (Var.Map.union + (fun _ x1 x2 -> + assert (x1 == x2); + Some x1) + t1 t2) + + let print ppf (Env t) = + Format.pp_print_list ~pp_sep:Format.pp_print_space + (fun ppf (v, { contents = _e, _env }) -> Print.var_debug ppf v) + ppf (Var.Map.bindings t) +end + +let rec lazy_eval : + decl_ctx -> + 'm Env.t -> + laziness_level -> + (dcalc, 'm mark) gexpr -> + (dcalc, 'm mark) gexpr * 'm Env.t = + fun ctx env llevel e0 -> + let eval_to_value ?(eval_default = true) env e = + lazy_eval ctx env { value_level with eval_default } e + in + match e0 with + | EVar v, _ -> + if not llevel.eval_default then e0, env + else + (* Variables reducing to EEmpty should not propagate to parent EDefault + (?) *) + let v_env = + try Env.find v env + with Not_found -> + error e0 "Variable %a undefined [@[%a@]]" Print.var_debug v + Env.print env + in + let e, env1 = !v_env in + let r, env1 = lazy_eval ctx env1 llevel e in + if not (Expr.equal e r) then ( + log "@[{{%a =@ [%a]@ ==> [%a]}}@]" Print.var_debug v + (Print.expr ~debug:true ctx) + e + (Print.expr ~debug:true ctx) + r; + v_env := r, env1); + r, Env.join env env1 + | EApp { f; args }, m -> ( + if + (not llevel.eval_default) + && not (List.equal Expr.equal args [ELit LUnit, m]) + (* Applications to () encode thunked default terms *) + then e0, env + else + match eval_to_value env f with + | (EAbs { binder; _ }, _), env -> + let vars, body = Bindlib.unmbind binder in + log "@[@[{"; + let env = + Seq.fold_left2 + (fun env1 var e -> + log "@[LET %a = %a@]@ " Print.var_debug var + (Print.expr ~debug:true ctx) + e; + Env.add var e env env1) + env (Array.to_seq vars) (List.to_seq args) + in + log "@]@[IN [%a]@]" (Print.expr ~debug:true ctx) body; + let e, env = lazy_eval ctx env llevel body in + log "@]}"; + e, env + | ((EOp { op; _ }, m) as f), env -> + let env, args = + List.fold_left_map + (fun env e -> + let e, env = lazy_eval ctx env llevel e in + env, e) + env args + in + if not llevel.eval_op then (EApp { f; args }, m), env + else + let renv = ref env in + (* Dirty workaround returning env from evaluate_operator *) + let eval e = + let e, env = lazy_eval ctx !renv llevel e in + renv := env; + e + in + Interpreter.evaluate_operator eval ctx op m args, !renv + (* fixme: this forwards eempty *) + | e, _ -> error e "Invalid apply on %a" (Print.expr ctx) e) + | (EAbs _ | ELit _ | EOp _ | EEmptyError), _ -> e0, env (* these are values *) + | (EStruct _ | ETuple _ | EInj _ | EArray _), _ -> + if not llevel.eval_struct then e0, env + else + let env, e = + Expr.map_gather ~acc:env ~join:Env.join + ~f:(fun e -> + let e, env = lazy_eval ctx env llevel e in + env, Expr.box e) + e0 + in + Expr.unbox e, env + | EStructAccess { e; name; field }, _ -> ( + if not llevel.eval_default then e0, env + else + match eval_to_value env e with + | (EStruct { name = n; fields }, _), env when StructName.equal name n -> + lazy_eval ctx env llevel (StructField.Map.find field fields) + | e, _ -> error e "Invalid field access on %a" (Print.expr ctx) e) + | ETupleAccess { e; index; size }, _ -> ( + if not llevel.eval_default then e0, env + else + match eval_to_value env e with + | (ETuple es, _), env when List.length es = size -> + lazy_eval ctx env llevel (List.nth es index) + | e, _ -> error e "Invalid tuple access on %a" (Print.expr ctx) e) + | EMatch { e; name; cases }, _ -> ( + if not llevel.eval_default then e0, env + else + match eval_to_value env e with + | (EInj { name = n; cons; e }, m), env when EnumName.equal name n -> + lazy_eval ctx env llevel + (EApp { f = EnumConstructor.Map.find cons cases; args = [e] }, m) + | e, _ -> error e "Invalid match argument %a" (Print.expr ctx) e) + | EDefault { excepts; just; cons }, m -> ( + let excs = + List.filter_map + (fun e -> + match eval_to_value env e ~eval_default:false with + | (EEmptyError, _), _ -> None + | e -> Some e) + excepts + in + match excs with + | [] -> ( + match eval_to_value env just with + | (ELit (LBool true), _), _ -> lazy_eval ctx env llevel cons + | (ELit (LBool false), _), _ -> (EEmptyError, m), env + | e, _ -> error e "Invalid exception justification %a" (Print.expr ctx) e) + | [(e, env)] -> + log "@[EVAL %a@]" (Print.expr ctx) e; + lazy_eval ctx env llevel e + | _ :: _ :: _ -> + Errors.raise_multispanned_error + ((None, Expr.mark_pos m) + :: List.map (fun (e, _) -> None, Expr.pos e) excs) + "Conflicting exceptions") + | EIfThenElse { cond; etrue; efalse }, _ -> ( + match eval_to_value env cond with + | (ELit (LBool true), _), _ -> lazy_eval ctx env llevel etrue + | (ELit (LBool false), _), _ -> lazy_eval ctx env llevel efalse + | e, _ -> error e "Invalid condition %a" (Print.expr ctx) e) + | EErrorOnEmpty e, _ -> ( + match eval_to_value env e ~eval_default:false with + | ((EEmptyError, _) as e'), _ -> + (* This does _not_ match the eager semantics ! *) + error e' "This value is undefined %a" (Print.expr ctx) e + | e, env -> lazy_eval ctx env llevel e) + | EAssert e, m -> ( + if noassert then (ELit LUnit, m), env + else + match eval_to_value env e with + | (ELit (LBool true), m), env -> (ELit LUnit, m), env + | (ELit (LBool false), _), _ -> + error e "Assert failure (%a)" (Print.expr ctx) e + | _ -> error e "Invalid assertion condition %a" (Print.expr ctx) e) + | _ -> . + +let interpret_program + (prg : ('dcalc, 'm mark) gexpr program) + (scope : ScopeName.t) : ('t, 'm mark) gexpr * 'm Env.t = + let ctx = prg.decl_ctx in + let all_env, scopes = + Scope.fold_left prg.code_items ~init:(Env.empty, ScopeName.Map.empty) + ~f:(fun (env, scopes) item v -> + match item with + | ScopeDef (name, body) -> + let e = Scope.to_expr ctx body (Scope.get_body_mark body) in + ( Env.add v (Expr.unbox e) env env, + ScopeName.Map.add name (v, body.scope_body_input_struct) scopes ) + | Topdef (_, _, e) -> Env.add v e env env, scopes) + in + let scope_v, scope_arg_struct = ScopeName.Map.find scope scopes in + let { contents = e, env } = Env.find scope_v all_env in + let e = Expr.unbox (Expr.remove_logging_calls e) in + log "====================="; + log "%a" (Print.expr ~debug:true ctx) e; + log "====================="; + let m = Marked.get_mark e in + let application_arg = + Expr.estruct scope_arg_struct + (StructField.Map.map + (function + | TArrow (ty_in, ty_out), _ -> + Expr.make_abs + [| Var.make "_" |] + (Bindlib.box EEmptyError, Expr.with_ty m ty_out) + ty_in (Expr.mark_pos m) + | ty -> Expr.evar (Var.make "undefined_input") (Expr.with_ty m ty)) + (StructName.Map.find scope_arg_struct ctx.ctx_structs)) + m + in + let e_app = Expr.eapp (Expr.box e) [application_arg] m in + lazy_eval ctx env + { value_level with eval_struct = true; eval_op = false } + (Expr.unbox e_app) + +(* -- Plugin registration -- *) + +let name = "lazy" +let extension = ".out" (* unused *) + +let apply ~source_file ~output_file ~scope prg _type_ordering = + let scope = + match scope with + | None -> Errors.raise_error "A scope must be specified" + | Some s -> s + in + ignore source_file; + (* File.with_formatter_of_opt_file output_file + * @@ fun fmt -> *) + ignore output_file; + let fmt = Format.std_formatter in + let result_expr, _env = interpret_program prg scope in + Print.expr prg.decl_ctx fmt result_expr + +let () = Driver.Plugin.register_dcalc ~name ~extension apply diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index e2e480a6..5296e951 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -191,15 +191,14 @@ let rec evaluate_operator EArray (List.map (fun e' -> - evaluate_expr ctx (Marked.same_mark_as (EApp { f; args = [e'] }) e')) + evaluate_expr (Marked.same_mark_as (EApp { f; args = [e'] }) e')) es) | Reduce, [_; default; (EArray [], _)] -> Marked.unmark default | Reduce, [f; _; (EArray (x0 :: xn), _)] -> Marked.unmark (List.fold_left (fun acc x -> - evaluate_expr ctx - (Marked.same_mark_as (EApp { f; args = [acc; x] }) f)) + evaluate_expr (Marked.same_mark_as (EApp { f; args = [acc; x] }) f)) x0 xn) | Concat, [(EArray es1, _); (EArray es2, _)] -> EArray (es1 @ es2) | Filter, [f; (EArray es, _)] -> @@ -207,8 +206,7 @@ let rec evaluate_operator (List.filter (fun e' -> match - evaluate_expr ctx - (Marked.same_mark_as (EApp { f; args = [e'] }) e') + evaluate_expr (Marked.same_mark_as (EApp { f; args = [e'] }) e') with | ELit (LBool b), _ -> b | _ -> @@ -221,8 +219,7 @@ let rec evaluate_operator Marked.unmark (List.fold_left (fun acc e' -> - evaluate_expr ctx - (Marked.same_mark_as (EApp { f; args = [acc; e'] }) e')) + evaluate_expr (Marked.same_mark_as (EApp { f; args = [acc; e'] }) e')) init es) | (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ -> err () | Not, [(ELit (LBool b), _)] -> ELit (LBool (o_not b)) @@ -409,7 +406,7 @@ let rec evaluate_expr : | Eq_dur_dur ) as op; _; } -> - evaluate_operator evaluate_expr ctx op m args + evaluate_operator (evaluate_expr ctx) ctx op m args | _ -> Errors.raise_spanned_error pos "function has not been reduced to a lambda at evaluation (should not \ diff --git a/compiler/shared_ast/interpreter.mli b/compiler/shared_ast/interpreter.mli index 4e5911bd..b311d837 100644 --- a/compiler/shared_ast/interpreter.mli +++ b/compiler/shared_ast/interpreter.mli @@ -20,6 +20,18 @@ open Catala_utils open Definitions +val evaluate_operator : + ((([< all ] as 'a), 'm mark) gexpr -> ('a, 'm mark) gexpr) -> + decl_ctx -> + [< dcalc | lcalc > `Monomorphic `Polymorphic `Resolved ] operator -> + 'm mark -> + ('a, 'm mark) gexpr list -> + ('a, 'm mark) gexpr +(** Evaluates the result of applying the given operator to the given arguments, + which are expected to be already reduced to values. The first argument is + used to evaluate expressions and called when reducing e.g. the [map] + operator. *) + val evaluate_expr : decl_ctx -> (([< dcalc | lcalc ] as 'a), 'm mark) gexpr -> ('a, 'm mark) gexpr (** Evaluates an expression according to the semantics of the default calculus. *)