mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Generalise optimisation of nested matches
This commit is contained in:
parent
78eaa16435
commit
3cbfa5f258
@ -19,43 +19,6 @@ open Definitions
|
|||||||
|
|
||||||
type ('a, 'b, 'm) optimizations_ctx = { decl_ctx : decl_ctx }
|
type ('a, 'b, 'm) optimizations_ctx = { decl_ctx : decl_ctx }
|
||||||
|
|
||||||
let all_match_cases_are_id_fun cases n =
|
|
||||||
EnumConstructor.Map.for_all
|
|
||||||
(fun i case ->
|
|
||||||
match Mark.remove case with
|
|
||||||
| EAbs { binder; _ } -> (
|
|
||||||
let var, body = Bindlib.unmbind binder in
|
|
||||||
(* because of invariant [invariant_match], the arity is always one. *)
|
|
||||||
let[@warning "-8"] [| var |] = var in
|
|
||||||
match Mark.remove body with
|
|
||||||
| EInj { cons = i'; name = n'; e = EVar x, _ } ->
|
|
||||||
EnumConstructor.equal i i'
|
|
||||||
&& EnumName.equal n n'
|
|
||||||
&& Bindlib.eq_vars x var
|
|
||||||
| EInj { cons = i'; name = n'; e = ELit LUnit, _ } ->
|
|
||||||
(* since unit is the only value of type unit. We don't need to check
|
|
||||||
the equality. *)
|
|
||||||
EnumConstructor.equal i i' && EnumName.equal n n'
|
|
||||||
| _ -> false)
|
|
||||||
| _ ->
|
|
||||||
(* because of invariant [invariant_match], there is always some EAbs in
|
|
||||||
each cases. *)
|
|
||||||
assert false)
|
|
||||||
cases
|
|
||||||
|
|
||||||
let all_match_cases_map_to_same_constructor cases n =
|
|
||||||
EnumConstructor.Map.for_all
|
|
||||||
(fun i case ->
|
|
||||||
match Mark.remove case with
|
|
||||||
| EAbs { binder; _ } -> (
|
|
||||||
let _, body = Bindlib.unmbind binder in
|
|
||||||
match Mark.remove body with
|
|
||||||
| EInj { cons = i'; name = n'; _ } ->
|
|
||||||
EnumConstructor.equal i i' && EnumName.equal n n'
|
|
||||||
| _ -> false)
|
|
||||||
| _ -> assert false)
|
|
||||||
cases
|
|
||||||
|
|
||||||
let binder_vars_used_at_most_once
|
let binder_vars_used_at_most_once
|
||||||
(binder :
|
(binder :
|
||||||
( ('a dcalc_lcalc, 'a dcalc_lcalc, 'm) base_gexpr,
|
( ('a dcalc_lcalc, 'a dcalc_lcalc, 'm) base_gexpr,
|
||||||
@ -79,6 +42,113 @@ let binder_vars_used_at_most_once
|
|||||||
in
|
in
|
||||||
not (Array.exists (fun c -> c > 1) (vars_count body))
|
not (Array.exists (fun c -> c > 1) (vars_count body))
|
||||||
|
|
||||||
|
(* beta reduction when variables not used, and for variable aliases and
|
||||||
|
literal *)
|
||||||
|
let simplified_apply f args tys =
|
||||||
|
match f with
|
||||||
|
| EAbs { binder; _ }, _
|
||||||
|
when binder_vars_used_at_most_once binder
|
||||||
|
|| List.for_all
|
||||||
|
(function (EVar _ | ELit _), _ -> true | _ -> false)
|
||||||
|
args ->
|
||||||
|
Mark.remove (Bindlib.msubst binder (List.map fst args |> Array.of_list))
|
||||||
|
| _ -> EApp { f; args; tys }
|
||||||
|
|
||||||
|
let literal_bool = function
|
||||||
|
| ELit (LBool b), _
|
||||||
|
| EAppOp { op = Log _, _; args = [(ELit (LBool b), _)]; _ }, _ ->
|
||||||
|
Some b
|
||||||
|
| _ -> None
|
||||||
|
|
||||||
|
let simplified_ifthenelse cond etrue efalse m =
|
||||||
|
if Expr.equal etrue efalse then Mark.remove etrue
|
||||||
|
else
|
||||||
|
match literal_bool etrue, literal_bool efalse with
|
||||||
|
| Some true, Some false -> Mark.remove cond
|
||||||
|
| Some false, Some true ->
|
||||||
|
EAppOp
|
||||||
|
{
|
||||||
|
op = Not, Expr.mark_pos m;
|
||||||
|
tys = [TLit TBool, Expr.mark_pos m];
|
||||||
|
args = [cond];
|
||||||
|
}
|
||||||
|
| Some true, Some true | Some false, Some false -> Mark.remove etrue
|
||||||
|
| _ -> (
|
||||||
|
match literal_bool cond with
|
||||||
|
| Some true -> Mark.remove etrue
|
||||||
|
| Some false -> Mark.remove efalse
|
||||||
|
| None -> EIfThenElse { cond; etrue; efalse })
|
||||||
|
|
||||||
|
(* builds a [EMatch] term, flattening nested matches/if-then-else: the matching
|
||||||
|
arg branching are explored, and if they all lead to enum constructor
|
||||||
|
literals, the surrounding match cases are inlined. Code duplication is
|
||||||
|
detected and aborts the inlining. *)
|
||||||
|
let simplified_match enum_name match_arg cases mark =
|
||||||
|
let max_duplicate_inlining_size = 3 in
|
||||||
|
let allow_duplicate_inlining_cases =
|
||||||
|
EnumConstructor.Map.fold
|
||||||
|
(fun cons f acc ->
|
||||||
|
if Expr.size f <= max_duplicate_inlining_size then
|
||||||
|
EnumConstructor.Set.add cons acc
|
||||||
|
else acc)
|
||||||
|
cases EnumConstructor.Set.empty
|
||||||
|
in
|
||||||
|
let app_cases cons e =
|
||||||
|
simplified_apply
|
||||||
|
(EnumConstructor.Map.find cons cases)
|
||||||
|
[e]
|
||||||
|
[Expr.maybe_ty (Mark.get e)]
|
||||||
|
in
|
||||||
|
let ret_ty = Expr.maybe_ty mark in
|
||||||
|
let rec aux seen_constrs = function
|
||||||
|
| EInj { cons; e; _ }, m ->
|
||||||
|
if EnumConstructor.Set.mem cons seen_constrs then raise Exit;
|
||||||
|
(* Abort inlining to avoid code duplication *)
|
||||||
|
let seen_constrs =
|
||||||
|
if EnumConstructor.Set.mem cons allow_duplicate_inlining_cases then
|
||||||
|
seen_constrs
|
||||||
|
else EnumConstructor.Set.add cons seen_constrs
|
||||||
|
in
|
||||||
|
seen_constrs, (app_cases cons e, Expr.with_ty m ret_ty)
|
||||||
|
| EMatch ({ cases; _ } as ematch), m ->
|
||||||
|
let seen_constrs, cases =
|
||||||
|
EnumConstructor.Map.fold
|
||||||
|
(fun cons case (seen_constrs, acc) ->
|
||||||
|
match case with
|
||||||
|
| EAbs ({ binder; _ } as eabs), m ->
|
||||||
|
let vars, body = Bindlib.unmbind binder in
|
||||||
|
let seen_constrs, body = aux seen_constrs body in
|
||||||
|
let binder = Bindlib.unbox (Expr.bind vars (Expr.rebox body)) in
|
||||||
|
let m =
|
||||||
|
Expr.map_ty
|
||||||
|
(function
|
||||||
|
| TArrow (args, _), pos -> TArrow (args, ret_ty), pos
|
||||||
|
| (TAny, _) as t -> t
|
||||||
|
| _ -> assert false)
|
||||||
|
m
|
||||||
|
in
|
||||||
|
( seen_constrs,
|
||||||
|
EnumConstructor.Map.add cons (EAbs { eabs with binder }, m) acc
|
||||||
|
)
|
||||||
|
| _ -> assert false)
|
||||||
|
cases
|
||||||
|
(seen_constrs, EnumConstructor.Map.empty)
|
||||||
|
in
|
||||||
|
seen_constrs, (EMatch { ematch with cases }, Expr.with_ty m ret_ty)
|
||||||
|
| EIfThenElse { cond; etrue; efalse }, m ->
|
||||||
|
let seen_constrs, etrue = aux seen_constrs etrue in
|
||||||
|
let seen_constrs, efalse = aux seen_constrs efalse in
|
||||||
|
let mark = Expr.with_ty m ret_ty in
|
||||||
|
seen_constrs, (simplified_ifthenelse cond etrue efalse mark, mark)
|
||||||
|
| _ -> raise Exit
|
||||||
|
in
|
||||||
|
try
|
||||||
|
let _seen_contrs, e = aux EnumConstructor.Set.empty match_arg in
|
||||||
|
Mark.remove e
|
||||||
|
with Exit ->
|
||||||
|
(* Optimisation was aborted due a non-terminal or code duplication *)
|
||||||
|
EMatch { e = match_arg; cases; name = enum_name }
|
||||||
|
|
||||||
let rec optimize_expr :
|
let rec optimize_expr :
|
||||||
type a b.
|
type a b.
|
||||||
(a, b, 'm) optimizations_ctx ->
|
(a, b, 'm) optimizations_ctx ->
|
||||||
@ -108,59 +178,8 @@ let rec optimize_expr :
|
|||||||
| EAppOp { op = And, _; args = [(e, _); (ELit (LBool b), _)]; _ } ->
|
| EAppOp { op = And, _; args = [(e, _); (ELit (LBool b), _)]; _ } ->
|
||||||
(* reduction of logical and *)
|
(* reduction of logical and *)
|
||||||
if b then e else ELit (LBool false)
|
if b then e else ELit (LBool false)
|
||||||
| EMatch { e = EInj { e = e'; cons; name = n' }, _; cases; name = n }
|
| EMatch { name; e; cases } -> simplified_match name e cases mark
|
||||||
(* iota-reduction *)
|
| EApp { f; args; tys } -> simplified_apply f args tys
|
||||||
when EnumName.equal n n' -> (
|
|
||||||
(* match E x with | E y -> e1 = e1[y |-> x]*)
|
|
||||||
match Mark.remove @@ EnumConstructor.Map.find cons cases with
|
|
||||||
(* holds because of invariant_match_inversion *)
|
|
||||||
| EAbs { binder; _ } ->
|
|
||||||
Mark.remove
|
|
||||||
(Bindlib.msubst binder ([e'] |> List.map fst |> Array.of_list))
|
|
||||||
| _ -> assert false)
|
|
||||||
| EMatch { e = e'; cases; name = n } when all_match_cases_are_id_fun cases n
|
|
||||||
->
|
|
||||||
(* iota-reduction when the match is equivalent to an identity function *)
|
|
||||||
Mark.remove e'
|
|
||||||
| EMatch
|
|
||||||
{
|
|
||||||
e = EMatch { e = arg; cases = cases1; name = n1 }, _;
|
|
||||||
cases = cases2;
|
|
||||||
name = n2;
|
|
||||||
}
|
|
||||||
when EnumName.equal n1 n2
|
|
||||||
&& all_match_cases_map_to_same_constructor cases1 n1 ->
|
|
||||||
(* iota-reduction when the matched expression is itself a match of the
|
|
||||||
same enum mapping all constructors to themselves *)
|
|
||||||
let cases =
|
|
||||||
EnumConstructor.Map.merge
|
|
||||||
(fun _i o1 o2 ->
|
|
||||||
match o1, o2 with
|
|
||||||
| Some b1, Some e2 -> (
|
|
||||||
match Mark.remove b1, Mark.remove e2 with
|
|
||||||
| EAbs { binder = b1; _ }, EAbs { binder = b2; tys } -> (
|
|
||||||
let v1, e1 = Bindlib.unmbind b1 in
|
|
||||||
match Mark.remove e1 with
|
|
||||||
| EInj { e = e1, _; _ } ->
|
|
||||||
Some
|
|
||||||
(Expr.unbox
|
|
||||||
(Expr.make_abs v1
|
|
||||||
(Expr.rebox (Bindlib.msubst b2 [| e1 |]))
|
|
||||||
tys (Expr.pos e2)))
|
|
||||||
| _ -> assert false)
|
|
||||||
| _ -> assert false)
|
|
||||||
| _ -> assert false)
|
|
||||||
cases1 cases2
|
|
||||||
in
|
|
||||||
EMatch { e = arg; cases; name = n1 }
|
|
||||||
| EApp { f = EAbs { binder; _ }, _; args; _ }
|
|
||||||
when binder_vars_used_at_most_once binder
|
|
||||||
|| List.for_all
|
|
||||||
(function (EVar _ | ELit _), _ -> true | _ -> false)
|
|
||||||
args ->
|
|
||||||
(* beta reduction when variables not used, and for variable aliases and
|
|
||||||
literal *)
|
|
||||||
Mark.remove (Bindlib.msubst binder (List.map fst args |> Array.of_list))
|
|
||||||
| EStructAccess { name; field; e = EStruct { name = name1; fields }, _ }
|
| EStructAccess { name; field; e = EStruct { name = name1; fields }, _ }
|
||||||
when StructName.equal name name1 ->
|
when StructName.equal name name1 ->
|
||||||
Mark.remove (StructField.Map.find field fields)
|
Mark.remove (StructField.Map.find field fields)
|
||||||
@ -193,12 +212,7 @@ let rec optimize_expr :
|
|||||||
->
|
->
|
||||||
(* No exceptions with condition [true] *)
|
(* No exceptions with condition [true] *)
|
||||||
Mark.remove cons
|
Mark.remove cons
|
||||||
| ( [],
|
| [], cond -> simplified_ifthenelse cond cons (EEmpty, mark) mark
|
||||||
( ( ELit (LBool false)
|
|
||||||
| EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ),
|
|
||||||
_ ) ) ->
|
|
||||||
(* No exceptions and condition false *)
|
|
||||||
EEmpty
|
|
||||||
| ( [except],
|
| ( [except],
|
||||||
( ( ELit (LBool false)
|
( ( ELit (LBool false)
|
||||||
| EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ),
|
| EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ),
|
||||||
@ -206,49 +220,8 @@ let rec optimize_expr :
|
|||||||
(* Single exception and condition false *)
|
(* Single exception and condition false *)
|
||||||
Mark.remove except
|
Mark.remove except
|
||||||
| excepts, just -> EDefault { excepts; just; cons })
|
| excepts, just -> EDefault { excepts; just; cons })
|
||||||
| EIfThenElse
|
| EIfThenElse { cond; etrue; efalse } ->
|
||||||
{
|
simplified_ifthenelse cond etrue efalse mark
|
||||||
cond =
|
|
||||||
( ELit (LBool true), _
|
|
||||||
| EAppOp { op = Log _, _; args = [(ELit (LBool true), _)]; _ }, _ );
|
|
||||||
etrue;
|
|
||||||
_;
|
|
||||||
} ->
|
|
||||||
Mark.remove etrue
|
|
||||||
| EIfThenElse
|
|
||||||
{
|
|
||||||
cond =
|
|
||||||
( ( ELit (LBool false)
|
|
||||||
| EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ),
|
|
||||||
_ );
|
|
||||||
efalse;
|
|
||||||
_;
|
|
||||||
} ->
|
|
||||||
Mark.remove efalse
|
|
||||||
| EIfThenElse
|
|
||||||
{
|
|
||||||
cond;
|
|
||||||
etrue =
|
|
||||||
( ( ELit (LBool btrue)
|
|
||||||
| EAppOp { op = Log _, _; args = [(ELit (LBool btrue), _)]; _ } ),
|
|
||||||
_ );
|
|
||||||
efalse =
|
|
||||||
( ( ELit (LBool bfalse)
|
|
||||||
| EAppOp { op = Log _, _; args = [(ELit (LBool bfalse), _)]; _ }
|
|
||||||
),
|
|
||||||
_ );
|
|
||||||
} ->
|
|
||||||
if btrue && not bfalse then Mark.remove cond
|
|
||||||
else if (not btrue) && bfalse then
|
|
||||||
EAppOp
|
|
||||||
{
|
|
||||||
op = Not, Expr.mark_pos mark;
|
|
||||||
tys = [TLit TBool, Expr.mark_pos mark];
|
|
||||||
args = [cond];
|
|
||||||
}
|
|
||||||
(* note: this last call eliminates the condition & might skip log calls
|
|
||||||
as well *)
|
|
||||||
else (* btrue = bfalse *) ELit (LBool btrue)
|
|
||||||
| EAppOp { op = Op.Fold, _; args = [_f; init; (EArray [], _)]; _ } ->
|
| EAppOp { op = Op.Fold, _; args = [_f; init; (EArray [], _)]; _ } ->
|
||||||
(*reduces a fold with an empty list *)
|
(*reduces a fold with an empty list *)
|
||||||
Mark.remove init
|
Mark.remove init
|
||||||
|
@ -131,13 +131,9 @@ let scope Foo
|
|||||||
Foo_in.b_in
|
Foo_in.b_in
|
||||||
in
|
in
|
||||||
let set b : bool =
|
let set b : bool =
|
||||||
match
|
match (handle_exceptions [b.0 b.1 ()]) with
|
||||||
(match (handle_exceptions [b.0 b.1 ()]) with
|
| ENone → true
|
||||||
| ENone → ESome true
|
| ESome x → x
|
||||||
| ESome x → ESome x)
|
|
||||||
with
|
|
||||||
| ENone → error NoValue
|
|
||||||
| ESome arg → arg
|
|
||||||
in
|
in
|
||||||
let set r :
|
let set r :
|
||||||
Result {
|
Result {
|
||||||
|
Loading…
Reference in New Issue
Block a user