mirror of
https://github.com/CatalaLang/catala.git
synced 2024-09-19 16:28:12 +03:00
Working partial evaluation for Dcalc using ugly but correct style
This commit is contained in:
parent
743a1b74c9
commit
ad4218285d
@ -14,144 +14,99 @@
|
||||
open Utils
|
||||
open Ast
|
||||
|
||||
let ( let+ ) x f = Bindlib.box_apply f x
|
||||
type partial_evaluation_ctx = expr Pos.marked Ast.VarMap.t
|
||||
|
||||
let ( and+ ) x y = Bindlib.box_pair x y
|
||||
|
||||
let visitor_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx : 'a)
|
||||
(e : expr Pos.marked) : expr Pos.marked Bindlib.box =
|
||||
(* calls [t ctx] on every direct childs of [e], then rebuild an abstract syntax tree modified.
|
||||
Used in other transformations. *)
|
||||
let default_mark e' = Pos.same_pos_as e' e in
|
||||
let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : expr Pos.marked) :
|
||||
expr Pos.marked Bindlib.box =
|
||||
let pos = Pos.get_position e in
|
||||
let rec_helper = partial_evaluation ctx in
|
||||
match Pos.unmark e with
|
||||
| EVar (v, pos) ->
|
||||
let+ v = Bindlib.box_var v in
|
||||
(v, pos)
|
||||
| ETuple (args, n) ->
|
||||
let+ args = args |> List.map (t ctx) |> Bindlib.box_list in
|
||||
default_mark @@ ETuple (args, n)
|
||||
| ETupleAccess (e1, i, n, ts) ->
|
||||
let+ e1 = t ctx e1 in
|
||||
default_mark @@ ETupleAccess (e1, i, n, ts)
|
||||
| EInj (e1, i, n, ts) ->
|
||||
let+ e1 = t ctx e1 in
|
||||
default_mark @@ EInj (e1, i, n, ts)
|
||||
| EMatch (arg, cases, n) ->
|
||||
let+ arg = t ctx arg and+ cases = cases |> List.map (t ctx) |> Bindlib.box_list in
|
||||
default_mark @@ EMatch (arg, cases, n)
|
||||
| EArray args ->
|
||||
let+ args = args |> List.map (t ctx) |> Bindlib.box_list in
|
||||
default_mark @@ EArray args
|
||||
| EAbs ((binder, pos_binder), ts) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let body = t ctx body in
|
||||
let+ binder = Bindlib.bind_mvar vars body in
|
||||
default_mark @@ EAbs ((binder, pos_binder), ts)
|
||||
| EApp (e1, args) ->
|
||||
let+ e1 = t ctx e1 and+ args = args |> List.map (t ctx) |> Bindlib.box_list in
|
||||
default_mark @@ EApp (e1, args)
|
||||
| EAssert e1 ->
|
||||
let+ e1 = t ctx e1 in
|
||||
default_mark @@ EAssert e1
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
let+ e1 = t ctx e1 and+ e2 = t ctx e2 and+ e3 = t ctx e3 in
|
||||
default_mark @@ EIfThenElse (e1, e2, e3)
|
||||
| ErrorOnEmpty e1 ->
|
||||
let+ e1 = t ctx e1 in
|
||||
default_mark @@ ErrorOnEmpty e1
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
let+ exceptions = exceptions |> List.map (t ctx) |> Bindlib.box_list
|
||||
and+ just = t ctx just
|
||||
and+ cons = t ctx cons in
|
||||
default_mark @@ EDefault (exceptions, just, cons)
|
||||
| ELit _ | EOp _ -> Bindlib.box e
|
||||
|
||||
let rec iota_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
|
||||
let default_mark e' = Pos.same_pos_as e' e in
|
||||
match Pos.unmark e with
|
||||
| EMatch ((EInj (e1, i, n', _ts), _), cases, n) when Ast.EnumName.compare n n' = 0 ->
|
||||
let+ e1 = visitor_map iota_expr () e1
|
||||
and+ case = visitor_map iota_expr () (List.nth cases i) in
|
||||
default_mark @@ EApp (case, [ e1 ])
|
||||
| EMatch (e', cases, n)
|
||||
when begin
|
||||
cases
|
||||
|> List.mapi (fun i (case, _pos) ->
|
||||
match case with
|
||||
| EInj (_ei, i', n', _ts') -> i = i' && (* n = n' *) Ast.EnumName.compare n n' = 0
|
||||
| _ -> false)
|
||||
|> List.for_all Fun.id
|
||||
end ->
|
||||
visitor_map iota_expr () e'
|
||||
| _ -> visitor_map iota_expr () e
|
||||
|
||||
let rec beta_expr (_ : unit) (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
|
||||
let default_mark e' = Pos.same_pos_as e' e in
|
||||
match Pos.unmark e with
|
||||
| EApp (e1, args) -> (
|
||||
let+ e1 = visitor_map beta_expr () e1
|
||||
and+ args = List.map (visitor_map beta_expr ()) args |> Bindlib.box_list in
|
||||
match Pos.unmark e1 with
|
||||
| EAbs ((binder, _pos_binder), _ts) ->
|
||||
let (_ : (_, _) Bindlib.mbinder) = binder in
|
||||
Bindlib.msubst binder (List.map fst args |> Array.of_list)
|
||||
| _ -> default_mark @@ EApp (e1, args))
|
||||
| _ -> visitor_map beta_expr () e
|
||||
|
||||
(**TODO: refactor this using plain recursion because this new visitor paradigm does not perform the
|
||||
optimizations after that the children of an AST node have been optimized, see for instance
|
||||
[(false
|
||||
|| false) || e1]. *)
|
||||
let rec peephole_expr (ctx : decl_ctx) (e : expr Pos.marked) : expr Pos.marked Bindlib.box =
|
||||
let default_mark e' = Pos.same_pos_as e' e in
|
||||
match Pos.unmark e with
|
||||
| EIfThenElse (e1, e2, e3) -> (
|
||||
let+ new_e1 = visitor_map peephole_expr ctx e1
|
||||
and+ new_e2 = visitor_map peephole_expr ctx e2
|
||||
and+ new_e3 = visitor_map peephole_expr ctx e3 in
|
||||
match (Pos.unmark new_e1, Pos.unmark new_e2, Pos.unmark new_e3) with
|
||||
| ELit (LBool true), _, _ | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]), _, _
|
||||
->
|
||||
new_e2
|
||||
| ELit (LBool false), _, _ | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]), _, _
|
||||
->
|
||||
new_e3
|
||||
| ( _,
|
||||
(ELit (LBool true) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ])),
|
||||
(ELit (LBool false) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ])) ) ->
|
||||
e1
|
||||
| _ -> default_mark @@ EIfThenElse (new_e1, new_e2, new_e3))
|
||||
| EApp
|
||||
( ((EOp (Binop Or), _ | EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop Or), _) ]), _) as op),
|
||||
[ e1; e2 ] ) -> (
|
||||
let+ new_e1 = visitor_map peephole_expr ctx e1
|
||||
and+ new_e2 = visitor_map peephole_expr ctx e2 in
|
||||
match (Pos.unmark new_e1, Pos.unmark new_e2) with
|
||||
| ELit (LBool false), new_e1 | new_e1, ELit (LBool false) -> default_mark @@ new_e1
|
||||
| ELit (LBool true), _ | _, ELit (LBool true) -> default_mark @@ ELit (LBool true)
|
||||
| _ -> default_mark @@ EApp (op, [ new_e1; new_e2 ]))
|
||||
| EApp (((EOp (Binop And), _) as op), [ e1; e2 ]) -> (
|
||||
let+ new_e1 = visitor_map peephole_expr ctx e1
|
||||
and+ new_e2 = visitor_map peephole_expr ctx e2 in
|
||||
match (new_e1, new_e2) with
|
||||
| (ELit (LBool true), _), new_e1 | new_e1, (ELit (LBool true), _) -> new_e1
|
||||
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) ->
|
||||
default_mark @@ ELit (LBool false)
|
||||
| _ -> default_mark @@ EApp (op, [ new_e1; new_e2 ]))
|
||||
| _ -> visitor_map peephole_expr ctx e
|
||||
[ e1; e2 ] ) ->
|
||||
(* reduction of logical or *)
|
||||
(Bindlib.box_apply2 (fun e1 e2 ->
|
||||
match (e1, e2) with
|
||||
| (ELit (LBool false), _), new_e | new_e, (ELit (LBool false), _) -> new_e
|
||||
| (ELit (LBool true), _), _ | _, (ELit (LBool true), _) -> (ELit (LBool true), pos)
|
||||
| _ -> (EApp (op, [ e1; e2 ]), pos)))
|
||||
(rec_helper e1) (rec_helper e2)
|
||||
| EApp
|
||||
( ((EOp (Binop And), _ | EApp ((EOp (Unop (Log _)), _), [ (EOp (Binop And), _) ]), _) as op),
|
||||
[ e1; e2 ] ) ->
|
||||
(* reduction of logical and *)
|
||||
(Bindlib.box_apply2 (fun e1 e2 ->
|
||||
match (e1, e2) with
|
||||
| (ELit (LBool true), _), new_e | new_e, (ELit (LBool true), _) -> new_e
|
||||
| (ELit (LBool false), _), _ | _, (ELit (LBool false), _) -> (ELit (LBool false), pos)
|
||||
| _ -> (EApp (op, [ e1; e2 ]), pos)))
|
||||
(rec_helper e1) (rec_helper e2)
|
||||
| EVar (x, _) -> Bindlib.box_apply (fun x -> (x, pos)) (Bindlib.box_var x)
|
||||
| ETuple (args, s_name) ->
|
||||
Bindlib.box_apply
|
||||
(fun args -> (ETuple (args, s_name), pos))
|
||||
(List.map rec_helper args |> Bindlib.box_list)
|
||||
| ETupleAccess (arg, i, s_name, typs) ->
|
||||
Bindlib.box_apply (fun arg -> (ETupleAccess (arg, i, s_name, typs), pos)) (rec_helper arg)
|
||||
| EInj (arg, i, e_name, typs) ->
|
||||
Bindlib.box_apply (fun arg -> (EInj (arg, i, e_name, typs), pos)) (rec_helper arg)
|
||||
| EMatch (arg, arms, e_name) ->
|
||||
Bindlib.box_apply2
|
||||
(fun arg arms ->
|
||||
match (arg, arms) with
|
||||
| (EInj (e1, i, e_name', _ts), _), _ when Ast.EnumName.compare e_name e_name' = 0 ->
|
||||
(* iota reduction *)
|
||||
(EApp (List.nth arms i, [ e1 ]), pos)
|
||||
| _ -> (EMatch (arg, arms, e_name), pos))
|
||||
(rec_helper arg)
|
||||
(List.map rec_helper arms |> Bindlib.box_list)
|
||||
| EArray args ->
|
||||
Bindlib.box_apply
|
||||
(fun args -> (EArray args, pos))
|
||||
(List.map rec_helper args |> Bindlib.box_list)
|
||||
| ELit l -> Bindlib.box (ELit l, pos)
|
||||
| EAbs ((binder, binder_pos), typs) ->
|
||||
let vars, body = Bindlib.unmbind binder in
|
||||
let new_body = partial_evaluation ctx body in
|
||||
let new_binder = Bindlib.bind_mvar vars new_body in
|
||||
Bindlib.box_apply (fun binder -> (EAbs ((binder, binder_pos), typs), pos)) new_binder
|
||||
| EApp (f, args) ->
|
||||
Bindlib.box_apply2
|
||||
(fun f args ->
|
||||
match Pos.unmark f with
|
||||
| EAbs ((binder, _pos_binder), _ts) ->
|
||||
(* beta reduction *)
|
||||
Bindlib.msubst binder (List.map fst args |> Array.of_list)
|
||||
| _ -> (EApp (f, args), pos))
|
||||
(rec_helper f)
|
||||
(List.map rec_helper args |> Bindlib.box_list)
|
||||
| EAssert e1 -> Bindlib.box_apply (fun e1 -> (EAssert e1, pos)) (rec_helper e1)
|
||||
| EOp op -> Bindlib.box (EOp op, pos)
|
||||
| EDefault (exceptions, just, cons) ->
|
||||
Bindlib.box_apply3
|
||||
(fun exceptions just cons -> (EDefault (exceptions, just, cons), pos))
|
||||
(List.map rec_helper exceptions |> Bindlib.box_list)
|
||||
(rec_helper just) (rec_helper cons)
|
||||
| EIfThenElse (e1, e2, e3) ->
|
||||
Bindlib.box_apply3
|
||||
(fun e1 e2 e3 ->
|
||||
match (Pos.unmark e1, Pos.unmark e2, Pos.unmark e3) with
|
||||
| ELit (LBool true), _, _
|
||||
| EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ]), _, _ ->
|
||||
e2
|
||||
| ELit (LBool false), _, _
|
||||
| EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ]), _, _ ->
|
||||
e3
|
||||
| ( _,
|
||||
(ELit (LBool true) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool true), _) ])),
|
||||
(ELit (LBool false) | EApp ((EOp (Unop (Log _)), _), [ (ELit (LBool false), _) ])) )
|
||||
->
|
||||
e1
|
||||
| _ -> (EIfThenElse (e1, e2, e3), pos))
|
||||
(rec_helper e1) (rec_helper e2) (rec_helper e3)
|
||||
| ErrorOnEmpty e1 -> Bindlib.box_apply (fun e1 -> (ErrorOnEmpty e1, pos)) (rec_helper e1)
|
||||
|
||||
let optimize_expr (ctx : decl_ctx) (e : expr Pos.marked) : expr Pos.marked =
|
||||
let e = ref e in
|
||||
let continue = ref true in
|
||||
while !continue do
|
||||
let new_e =
|
||||
!e |> peephole_expr ctx |> Bindlib.unbox |> beta_expr () |> Bindlib.unbox |> iota_expr ()
|
||||
|> Bindlib.unbox
|
||||
in
|
||||
if not (expr_size new_e < expr_size !e) then continue := false;
|
||||
e := new_e
|
||||
done;
|
||||
!e
|
||||
let optimize_expr = partial_evaluation VarMap.empty
|
||||
|
||||
let program_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx : 'a) (p : program)
|
||||
: program =
|
||||
@ -178,10 +133,4 @@ let program_map (t : 'a -> expr Pos.marked -> expr Pos.marked Bindlib.box) (ctx
|
||||
p.scopes;
|
||||
}
|
||||
|
||||
let iota_optimizations (p : program) : program = program_map iota_expr () p
|
||||
|
||||
let beta_optimizations (p : program) : program = program_map beta_expr () p
|
||||
|
||||
let peephole_optimizations (p : program) : program = program_map peephole_expr p.decl_ctx p
|
||||
|
||||
let optimize_program (p : program) : program = p |> iota_optimizations |> peephole_optimizations
|
||||
let optimize_program (p : program) : program = program_map partial_evaluation VarMap.empty p
|
||||
|
@ -111,7 +111,7 @@ let generate_verification_conditions (p : program) : expr Pos.marked list =
|
||||
let e = Bindlib.unbox s_let.scope_let_expr in
|
||||
let vc = generate_vc_must_not_return_empty ctx e in
|
||||
let vc =
|
||||
if !Cli.optimize_flag then Optimizations.optimize_expr p.decl_ctx vc else vc
|
||||
if !Cli.optimize_flag then Bindlib.unbox (Optimizations.optimize_expr vc) else vc
|
||||
in
|
||||
(* TODO: drop logs for Aymeric *)
|
||||
(Pos.same_pos_as (Pos.unmark vc) e :: acc, ctx)
|
||||
|
@ -1077,12 +1077,8 @@ let enfant_le_plus_age (enfant_le_plus_age_in : enfant_le_plus_age_in) =
|
||||
}
|
||||
true)
|
||||
(fun (_ : _) ->
|
||||
let predicate_ : _ =
|
||||
fun (potentiel_plus_age_ : _) -> potentiel_plus_age_.age
|
||||
in
|
||||
Array.fold_left
|
||||
(fun (acc_ : _) (item_ : _) ->
|
||||
if predicate_ acc_ >! predicate_ item_ then acc_ else item_)
|
||||
(fun (acc_ : _) (item_ : _) -> if acc_.age >! item_.age then acc_ else item_)
|
||||
{
|
||||
identifiant = ~-!(integer_of_string "1");
|
||||
obligation_scolaire = Pendant ();
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user