Add experimental lazy interpreter as a plugin

To try it (without installing Catala):
```shell-session
$ make plugins
$ export CATALA_PLUGINS=_build/default/compiler/plugins
$ dune exec -- catala lazy examples/aides_logement/tests/tests_calcul_apl_locatif.catala_fr -s Exemple2
```

Keep in mind that this is a work-in-progress prototype :)
This commit is contained in:
Louis Gesbert 2023-04-14 16:56:57 +02:00
parent feeee4016e
commit b4a68fa392
8 changed files with 334 additions and 46 deletions

View File

@ -293,7 +293,11 @@ let driver source_file (options : Cli.options) : int =
let output_file, _ = get_output_format ~ext:p.Plugin.extension () in
Cli.debug_print "Compiling program through backend \"%s\"..."
p.Plugin.name;
p.Plugin.apply ~source_file ~output_file ~scope:options.ex_scope
p.Plugin.apply ~source_file ~output_file
~scope:
(match options.ex_scope with
| None -> None
| Some _ -> Some scope_uid)
(Shared_ast.Program.untype prgm)
type_ordering
| (`OCaml | `Interpret_Lcalc | `Python | `Lcalc | `Scalc | `Plugin _)
@ -389,7 +393,11 @@ let driver source_file (options : Cli.options) : int =
in
Cli.debug_print "Compiling program through backend \"%s\"..."
p.Plugin.name;
p.Plugin.apply ~source_file ~output_file ~scope:options.ex_scope
p.Plugin.apply ~source_file ~output_file
~scope:
(match options.ex_scope with
| None -> None
| Some _ -> Some scope_uid)
prgm type_ordering
| (`Python | `Scalc | `Plugin (Plugin.Scalc _)) as backend -> (
let prgm = Scalc.From_lcalc.translate_program prgm in
@ -427,7 +435,11 @@ let driver source_file (options : Cli.options) : int =
Cli.debug_print "Writing to %s..."
(Option.value ~default:"stdout" output_file);
p.Plugin.apply ~source_file ~output_file
~scope:options.ex_scope prgm type_ordering)))))));
~scope:
(match options.ex_scope with
| None -> None
| Some _ -> Some scope_uid)
prgm type_ordering)))))));
0
with
| Errors.StructuredError (msg, pos) ->

View File

@ -19,7 +19,7 @@ open Catala_utils
type 'ast plugin_apply_fun_typ =
source_file:Pos.input_file ->
output_file:string option ->
scope:string option ->
scope:Shared_ast.ScopeName.t option ->
'ast ->
Scopelang.Dependency.TVertex.t list ->
unit

View File

@ -21,7 +21,7 @@ open Catala_utils
type 'ast plugin_apply_fun_typ =
source_file:Pos.input_file ->
output_file:string option ->
scope:string option ->
scope:Shared_ast.ScopeName.t option ->
'ast ->
Scopelang.Dependency.TVertex.t list ->
unit

View File

@ -20,6 +20,14 @@
(modules json_schema)
(libraries catala.driver))
; Experimental lazy interpreter, packaged as a Catala plugin library. It is
; built from the single module lazy_interp and depends on shared_ast and
; catala.driver.
(library
(name lazy_interpreter)
(public_name catala.plugins.lazy-interpreter)
(synopsis
"Catala plugin that implements a different, experimental interpreter, featuring lazy and partial evaluation")
(modules lazy_interp)
(libraries shared_ast catala.driver))
(documentation
(package catala)
(mld_files plugins))

View File

@ -22,7 +22,6 @@ let extension = "_schema.json"
open Catala_utils
open Shared_ast
open Lcalc.Ast
open Lcalc.To_ocaml
module D = Dcalc.Ast
@ -47,17 +46,6 @@ module To_json = struct
in
Format.fprintf fmt "%s" s
let rec find_scope_def (target_name : string) :
'm expr code_item_list -> (ScopeName.t * 'm expr scope_body) option =
function
| Nil -> None
| Cons (ScopeDef (name, body), _)
when String.equal target_name (Marked.unmark (ScopeName.get_info name)) ->
Some (name, body)
| Cons (_, next_bind) ->
let _, next_scope = Bindlib.unbind next_bind in
find_scope_def target_name next_scope
let fmt_tlit fmt (tlit : typ_lit) =
match tlit with
| TUnit -> Format.fprintf fmt "\"type\": \"null\",@\n\"default\": null"
@ -203,31 +191,29 @@ module To_json = struct
let format_program
(fmt : Format.formatter)
(scope : string)
(scope : ScopeName.t)
(prgm : 'm Lcalc.Ast.program) =
match find_scope_def scope prgm.code_items with
| None -> Cli.error_print "Internal error: scope '%s' not found." scope
| Some scope_def ->
Cli.call_unstyled (fun _ ->
Format.fprintf fmt
"{@[<hov 2>@\n\
\"type\": \"object\",@\n\
\"@[<hov 2>definitions\": {%a@]@\n\
},@\n\
\"@[<hov 2>properties\": {@\n\
%a@]@\n\
}@]@\n\
}"
(fmt_definitions prgm.decl_ctx)
scope_def
(fmt_struct_properties prgm.decl_ctx)
(snd scope_def).scope_body_input_struct)
let scope_body = Program.get_scope_body prgm scope in
Cli.call_unstyled (fun _ ->
Format.fprintf fmt
"{@[<hov 2>@\n\
\"type\": \"object\",@\n\
\"@[<hov 2>definitions\": {%a@]@\n\
},@\n\
\"@[<hov 2>properties\": {@\n\
%a@]@\n\
}@]@\n\
}"
(fmt_definitions prgm.decl_ctx)
(scope, scope_body)
(fmt_struct_properties prgm.decl_ctx)
scope_body.scope_body_input_struct)
end
let apply
~(source_file : Pos.input_file)
~(output_file : string option)
~(scope : string option)
~(scope : Shared_ast.ScopeName.t option)
(prgm : 'm Lcalc.Ast.program)
(type_ordering : Scopelang.Dependency.TVertex.t list) =
ignore source_file;
@ -236,9 +222,9 @@ let apply
| Some s ->
File.with_formatter_of_opt_file output_file (fun fmt ->
Cli.debug_print
"Writing JSON schema corresponding to the scope '%s' to the file \
"Writing JSON schema corresponding to the scope '%a' to the file \
%s..."
s
ScopeName.format_t s
(Option.value ~default:"stdout" output_file);
To_json.format_program fmt s prgm)
| None -> Cli.error_print "A scope must be specified for the plugin: %s" name

View File

@ -0,0 +1,273 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2023 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>.
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Shared_ast
(* -- Definition of the lazy interpreter -- *)
(** [log fmt arg1 ... argN] is the tracing hook of the lazy interpreter. A
    newline break hint is appended to [fmt], and the result is handed to
    [Format.ifprintf]: the arguments are consumed according to [fmt] but
    nothing is printed. *)
let log fmt =
  let with_newline = fmt ^^ "@\n" in
  Format.ifprintf Format.err_formatter with_newline
(* [error e fmt ...] reports an error located at the position of expression
   [e] (as returned by [Expr.pos]). The format string and its arguments are
   passed on to [Errors.raise_spanned_error]. *)
let error e = Errors.raise_spanned_error (Expr.pos e)
(* When [true], [lazy_eval] reduces every [EAssert] node to the unit literal
   without evaluating the asserted condition. *)
let noassert = true
(* Settings that tell [lazy_eval] how far an expression should be reduced. *)
type laziness_level = {
eval_struct : bool;
(* if true, evaluate members of structures, tuples, injections and arrays;
   if false, such constructions are returned unchanged *)
eval_op : bool;
(* if false, evaluate the operands but keep the operator application itself
   unreduced, e.g. keep `3 + 4` as is *)
eval_default : bool;
(* if false, stop evaluating as soon as you can discriminate with
   `EEmptyError`: variables, struct/tuple accesses, matches and applications
   (other than applications to a single unit argument) are returned
   unreduced *)
}
(** Laziness level used when an expression has to be reduced to a value:
    operators and defaults are evaluated, but the members of structures,
    tuples, etc. are left untouched. *)
let value_level =
  {
    eval_default = true;
    eval_op = true;
    eval_struct = false;
  }
(** Evaluation environments of the lazy interpreter. Each variable is mapped
    to a mutable cell holding an expression together with the environment in
    which that expression must be evaluated. The cell is mutable so that the
    interpreter can overwrite it once the expression has been reduced. *)
module Env = struct
  type 'm t =
    | Env of
        ((dcalc, 'm mark) gexpr, ((dcalc, 'm mark) gexpr * 'm t) ref) Var.Map.t

  (** The environment with no binding. *)
  let empty = Env Var.Map.empty

  (** [find v env] returns the cell bound to [v], as given by
      [Var.Map.find]. *)
  let find v (Env bindings) = Var.Map.find v bindings

  (** [add v e e_env env] binds [v] to a fresh cell containing [e] paired
      with [e_env], the environment [e] is to be evaluated in. *)
  let add v e e_env (Env bindings) =
    let cell = ref (e, e_env) in
    Env (Var.Map.add v cell bindings)

  (** [join env1 env2] is the union of both environments. A variable bound on
      both sides must be bound to the physically same cell, which is checked
      with an assertion. *)
  let join (Env left) (Env right) =
    let same_cell _var cell_left cell_right =
      assert (cell_left == cell_right);
      Some cell_left
    in
    Env (Var.Map.union same_cell left right)

  (** Prints the variables bound in the environment, separated by break
      hints. The bound expressions are not printed. *)
  let print ppf (Env bindings) =
    let pp_binding ppf (v, { contents = _e, _env }) = Print.var_debug ppf v in
    Format.pp_print_list ~pp_sep:Format.pp_print_space pp_binding ppf
      (Var.Map.bindings bindings)
end
(** [lazy_eval ctx env llevel e0] reduces the expression [e0] in environment
    [env], going only as far as [llevel] allows, and returns the reduced
    expression together with an environment.

    Variables are bound in [env] to mutable cells holding an unevaluated
    expression and its own environment. When a variable is evaluated, its
    cell is overwritten with the reduced expression.

    Errors are reported through [error] (invalid shapes, undefined variable,
    empty value under [EErrorOnEmpty]) and through
    [Errors.raise_multispanned_error] (conflicting exceptions). *)
let rec lazy_eval :
decl_ctx ->
'm Env.t ->
laziness_level ->
(dcalc, 'm mark) gexpr ->
(dcalc, 'm mark) gexpr * 'm Env.t =
fun ctx env llevel e0 ->
(* Reduces a sub-expression at [value_level]; [eval_default] can be overridden
   by the caller and defaults to [true]. *)
let eval_to_value ?(eval_default = true) env e =
lazy_eval ctx env { value_level with eval_default } e
in
match e0 with
| EVar v, _ ->
(* When defaults are not being evaluated, a variable is left as is. *)
if not llevel.eval_default then e0, env
else
(* Variables reducing to EEmpty should not propagate to parent EDefault
(?) *)
let v_env =
try Env.find v env
with Not_found ->
error e0 "Variable %a undefined [@[<hv>%a@]]" Print.var_debug v
Env.print env
in
(* The bound expression is evaluated in the environment stored with it. *)
let e, env1 = !v_env in
let r, env1 = lazy_eval ctx env1 llevel e in
(* If evaluation made progress, overwrite the cell so that later lookups of
   [v] get the reduced expression. *)
if not (Expr.equal e r) then (
log "@[<hv 2>{{%a =@ [%a]@ ==> [%a]}}@]" Print.var_debug v
(Print.expr ~debug:true ctx)
e
(Print.expr ~debug:true ctx)
r;
v_env := r, env1);
r, Env.join env env1
| EApp { f; args }, m -> (
if
(not llevel.eval_default)
&& not (List.equal Expr.equal args [ELit LUnit, m])
(* Applications to () encode thunked default terms *)
then e0, env
else
match eval_to_value env f with
| (EAbs { binder; _ }, _), env ->
(* Function call: each argument is bound unevaluated to the corresponding
   parameter, then the body is evaluated in the extended environment. *)
let vars, body = Bindlib.unmbind binder in
log "@[<v 2>@[<hov 4>{";
(* NOTE(review): Seq.fold_left2 stops as soon as one of the two sequences is
   exhausted, so a mismatch between the number of parameters and the number
   of arguments is not reported here; confirm this cannot happen. *)
let env =
Seq.fold_left2
(fun env1 var e ->
log "@[<hov 2>LET %a = %a@]@ " Print.var_debug var
(Print.expr ~debug:true ctx)
e;
Env.add var e env env1)
env (Array.to_seq vars) (List.to_seq args)
in
log "@]@[<hov 4>IN [%a]@]" (Print.expr ~debug:true ctx) body;
let e, env = lazy_eval ctx env llevel body in
log "@]}";
e, env
| ((EOp { op; _ }, m) as f), env ->
(* Operator application: the operands are evaluated from left to right, the
   environment being threaded from one to the next. *)
(* NOTE(review): [m] is rebound by the pattern above to the mark carried by
   the operator node and shadows the mark of the enclosing application;
   confirm that this is the intended mark for the result. *)
let env, args =
List.fold_left_map
(fun env e ->
let e, env = lazy_eval ctx env llevel e in
env, e)
env args
in
(* With [eval_op] unset, the application is rebuilt around the evaluated
   operands instead of being computed. *)
if not llevel.eval_op then (EApp { f; args }, m), env
else
let renv = ref env in
(* Dirty workaround returning env from evaluate_operator *)
let eval e =
let e, env = lazy_eval ctx !renv llevel e in
renv := env;
e
in
Interpreter.evaluate_operator eval ctx op m args, !renv
(* fixme: this forwards eempty *)
| e, _ -> error e "Invalid apply on %a" (Print.expr ctx) e)
| (EAbs _ | ELit _ | EOp _ | EEmptyError), _ -> e0, env (* these are values *)
| (EStruct _ | ETuple _ | EInj _ | EArray _), _ ->
(* Data constructions are only traversed when [eval_struct] is set; the
   environments resulting from the sub-evaluations are joined together. *)
if not llevel.eval_struct then e0, env
else
let env, e =
Expr.map_gather ~acc:env ~join:Env.join
~f:(fun e ->
let e, env = lazy_eval ctx env llevel e in
env, Expr.box e)
e0
in
Expr.unbox e, env
| EStructAccess { e; name; field }, _ -> (
if not llevel.eval_default then e0, env
else
(* Reduce the accessed expression to a structure of the expected name, then
   carry on with the selected field only. *)
match eval_to_value env e with
| (EStruct { name = n; fields }, _), env when StructName.equal name n ->
lazy_eval ctx env llevel (StructField.Map.find field fields)
| e, _ -> error e "Invalid field access on %a" (Print.expr ctx) e)
| ETupleAccess { e; index; size }, _ -> (
if not llevel.eval_default then e0, env
else
(* Same as for structures: only the selected component is evaluated. *)
match eval_to_value env e with
| (ETuple es, _), env when List.length es = size ->
lazy_eval ctx env llevel (List.nth es index)
| e, _ -> error e "Invalid tuple access on %a" (Print.expr ctx) e)
| EMatch { e; name; cases }, _ -> (
if not llevel.eval_default then e0, env
else
(* Reduce the scrutinee to an injection of the expected enumeration, then
   evaluate the application of the matching case to its payload. *)
match eval_to_value env e with
| (EInj { name = n; cons; e }, m), env when EnumName.equal name n ->
lazy_eval ctx env llevel
(EApp { f = EnumConstructor.Map.find cons cases; args = [e] }, m)
| e, _ -> error e "Invalid match argument %a" (Print.expr ctx) e)
| EDefault { excepts; just; cons }, m -> (
(* Exceptions are reduced with [eval_default] unset, and those that reduce to
   [EEmptyError] are dropped. Each remaining one is kept along with the
   environment returned by its evaluation. *)
let excs =
List.filter_map
(fun e ->
match eval_to_value env e ~eval_default:false with
| (EEmptyError, _), _ -> None
| e -> Some e)
excepts
in
match excs with
| [] -> (
(* No exception applies: the justification decides between the consequence
   and the empty error. *)
match eval_to_value env just with
| (ELit (LBool true), _), _ -> lazy_eval ctx env llevel cons
| (ELit (LBool false), _), _ -> (EEmptyError, m), env
| e, _ -> error e "Invalid exception justification %a" (Print.expr ctx) e)
| [(e, env)] ->
(* Exactly one exception applies: finish evaluating it at the requested
   level. *)
log "@[<hov 5>EVAL %a@]" (Print.expr ctx) e;
lazy_eval ctx env llevel e
| _ :: _ :: _ ->
(* Two or more exceptions apply: report all their positions. *)
Errors.raise_multispanned_error
((None, Expr.mark_pos m)
:: List.map (fun (e, _) -> None, Expr.pos e) excs)
"Conflicting exceptions")
| EIfThenElse { cond; etrue; efalse }, _ -> (
(* Only the branch selected by the condition is evaluated. *)
match eval_to_value env cond with
| (ELit (LBool true), _), _ -> lazy_eval ctx env llevel etrue
| (ELit (LBool false), _), _ -> lazy_eval ctx env llevel efalse
| e, _ -> error e "Invalid condition %a" (Print.expr ctx) e)
| EErrorOnEmpty e, _ -> (
match eval_to_value env e ~eval_default:false with
| ((EEmptyError, _) as e'), _ ->
(* This does _not_ match the eager semantics ! *)
error e' "This value is undefined %a" (Print.expr ctx) e
| e, env -> lazy_eval ctx env llevel e)
| EAssert e, m -> (
(* With [noassert] set, the assertion is replaced by unit and its condition is
   never evaluated. *)
if noassert then (ELit LUnit, m), env
else
match eval_to_value env e with
| (ELit (LBool true), m), env -> (ELit LUnit, m), env
| (ELit (LBool false), _), _ ->
error e "Assert failure (%a)" (Print.expr ctx) e
| _ -> error e "Invalid assertion condition %a" (Print.expr ctx) e)
(* Refutation case: tells the type-checker that no other constructor is
   possible for this kind of expression. *)
| _ -> .
(** [interpret_program prg scope] evaluates the scope [scope] of the program
    [prg] with the lazy interpreter and returns the resulting expression
    together with the final environment.

    The scope is applied to a placeholder input structure. Fields with a
    function type become functions returning [EEmptyError]; every other field
    becomes a fresh variable named [undefined_input] that is bound nowhere.
    The result is computed with structure members evaluated and operator
    applications left unreduced.

    Presumably raises [Not_found] if [scope] is not defined in [prg], since
    the scope is looked up with [ScopeName.Map.find] (to be confirmed). *)
let interpret_program
    (prg : ('dcalc, 'm mark) gexpr program)
    (scope : ScopeName.t) : ('t, 'm mark) gexpr * 'm Env.t =
  let ctx = prg.decl_ctx in
  (* Bind every top-level item to its variable. Each definition is stored with
     the environment built from the items that precede it. For scopes, also
     record the variable and the input structure by scope name. *)
  let all_env, scopes =
    Scope.fold_left prg.code_items
      ~init:(Env.empty, ScopeName.Map.empty)
      ~f:(fun (env, scopes) item v ->
        match item with
        | Topdef (_, _, e) -> Env.add v e env env, scopes
        | ScopeDef (name, body) ->
          let e = Scope.to_expr ctx body (Scope.get_body_mark body) in
          let env' = Env.add v (Expr.unbox e) env env in
          let scopes' =
            ScopeName.Map.add name (v, body.scope_body_input_struct) scopes
          in
          env', scopes')
  in
  let scope_v, scope_arg_struct = ScopeName.Map.find scope scopes in
  let { contents = e, env } = Env.find scope_v all_env in
  let e = e |> Expr.remove_logging_calls |> Expr.unbox in
  log "=====================";
  log "%a" (Print.expr ~debug:true ctx) e;
  log "=====================";
  let m = Marked.get_mark e in
  (* Build the placeholder input structure, one entry per declared field. *)
  let input_fields = StructName.Map.find scope_arg_struct ctx.ctx_structs in
  let placeholders =
    StructField.Map.map
      (function
        | TArrow (ty_in, ty_out), _ ->
          Expr.make_abs
            [| Var.make "_" |]
            (Bindlib.box EEmptyError, Expr.with_ty m ty_out)
            ty_in (Expr.mark_pos m)
        | ty -> Expr.evar (Var.make "undefined_input") (Expr.with_ty m ty))
      input_fields
  in
  let application_arg = Expr.estruct scope_arg_struct placeholders m in
  let call = Expr.unbox (Expr.eapp (Expr.box e) [application_arg] m) in
  let llevel = { value_level with eval_op = false; eval_struct = true } in
  lazy_eval ctx env llevel call
(* -- Plugin registration -- *)
(* Name under which the plugin is registered with the driver. The commit
   message invokes it as the [lazy] command of the catala executable. *)
let name = "lazy"
(* Output file extension passed at registration; [apply] ignores its
   [output_file] argument, hence the value is not used. *)
let extension = ".out" (* unused *)
(** Entry point of the plugin. Interprets the scope [scope] of [prg] with the
    lazy interpreter and prints the resulting expression on standard output.
    Fails through [Errors.raise_error] when no scope is given. The type
    ordering is not used. *)
let apply ~source_file ~output_file ~scope prg _type_ordering =
  match scope with
  | None -> Errors.raise_error "A scope must be specified"
  | Some scope ->
    (* Neither the source file nor the output file is used for now: the result
       always goes to the standard formatter. Writing through
       File.with_formatter_of_opt_file output_file is left disabled. *)
    ignore source_file;
    ignore output_file;
    let result_expr, _env = interpret_program prg scope in
    Print.expr prg.decl_ctx Format.std_formatter result_expr
(* Executed when this module is loaded: registers [apply] with the driver as
   a plugin taking a dcalc program, under [name] and [extension]. *)
let () = Driver.Plugin.register_dcalc ~name ~extension apply

View File

@ -191,15 +191,14 @@ let rec evaluate_operator
EArray
(List.map
(fun e' ->
evaluate_expr ctx (Marked.same_mark_as (EApp { f; args = [e'] }) e'))
evaluate_expr (Marked.same_mark_as (EApp { f; args = [e'] }) e'))
es)
| Reduce, [_; default; (EArray [], _)] -> Marked.unmark default
| Reduce, [f; _; (EArray (x0 :: xn), _)] ->
Marked.unmark
(List.fold_left
(fun acc x ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [acc; x] }) f))
evaluate_expr (Marked.same_mark_as (EApp { f; args = [acc; x] }) f))
x0 xn)
| Concat, [(EArray es1, _); (EArray es2, _)] -> EArray (es1 @ es2)
| Filter, [f; (EArray es, _)] ->
@ -207,8 +206,7 @@ let rec evaluate_operator
(List.filter
(fun e' ->
match
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [e'] }) e')
evaluate_expr (Marked.same_mark_as (EApp { f; args = [e'] }) e')
with
| ELit (LBool b), _ -> b
| _ ->
@ -221,8 +219,7 @@ let rec evaluate_operator
Marked.unmark
(List.fold_left
(fun acc e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [acc; e'] }) e'))
evaluate_expr (Marked.same_mark_as (EApp { f; args = [acc; e'] }) e'))
init es)
| (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ -> err ()
| Not, [(ELit (LBool b), _)] -> ELit (LBool (o_not b))
@ -409,7 +406,7 @@ let rec evaluate_expr :
| Eq_dur_dur ) as op;
_;
} ->
evaluate_operator evaluate_expr ctx op m args
evaluate_operator (evaluate_expr ctx) ctx op m args
| _ ->
Errors.raise_spanned_error pos
"function has not been reduced to a lambda at evaluation (should not \

View File

@ -20,6 +20,18 @@
open Catala_utils
open Definitions
val evaluate_operator :
((([< all ] as 'a), 'm mark) gexpr -> ('a, 'm mark) gexpr) ->
decl_ctx ->
[< dcalc | lcalc > `Monomorphic `Polymorphic `Resolved ] operator ->
'm mark ->
('a, 'm mark) gexpr list ->
('a, 'm mark) gexpr
(** [evaluate_operator eval ctx op m args] evaluates the result of applying
    the operator [op] to the arguments [args], which are expected to be
    already reduced to values. The first argument [eval] is the function used
    to evaluate expressions: it is called on the applications built when
    reducing higher-order operators such as [map], [filter], [reduce] and
    [fold]. Passing it explicitly lets other interpreters reuse the operator
    semantics with their own evaluation function. *)
val evaluate_expr :
decl_ctx -> (([< dcalc | lcalc ] as 'a), 'm mark) gexpr -> ('a, 'm mark) gexpr
(** [evaluate_expr ctx e] evaluates the expression [e] according to the
    semantics of the default calculus. The result is an expression of the
    same kind and with the same type of marks as [e]. *)