Generalise optimisation of nested matches

This commit is contained in:
Louis Gesbert 2024-08-27 16:56:38 +02:00
parent 78eaa16435
commit 3cbfa5f258
2 changed files with 115 additions and 146 deletions

View File

@ -19,43 +19,6 @@ open Definitions
type ('a, 'b, 'm) optimizations_ctx = { decl_ctx : decl_ctx } type ('a, 'b, 'm) optimizations_ctx = { decl_ctx : decl_ctx }
let all_match_cases_are_id_fun cases n =
EnumConstructor.Map.for_all
(fun i case ->
match Mark.remove case with
| EAbs { binder; _ } -> (
let var, body = Bindlib.unmbind binder in
(* because of invariant [invariant_match], the arity is always one. *)
let[@warning "-8"] [| var |] = var in
match Mark.remove body with
| EInj { cons = i'; name = n'; e = EVar x, _ } ->
EnumConstructor.equal i i'
&& EnumName.equal n n'
&& Bindlib.eq_vars x var
| EInj { cons = i'; name = n'; e = ELit LUnit, _ } ->
(* since unit is the only value of type unit. We don't need to check
the equality. *)
EnumConstructor.equal i i' && EnumName.equal n n'
| _ -> false)
| _ ->
(* because of invariant [invariant_match], there is always some EAbs in
each cases. *)
assert false)
cases
let all_match_cases_map_to_same_constructor cases n =
EnumConstructor.Map.for_all
(fun i case ->
match Mark.remove case with
| EAbs { binder; _ } -> (
let _, body = Bindlib.unmbind binder in
match Mark.remove body with
| EInj { cons = i'; name = n'; _ } ->
EnumConstructor.equal i i' && EnumName.equal n n'
| _ -> false)
| _ -> assert false)
cases
let binder_vars_used_at_most_once let binder_vars_used_at_most_once
(binder : (binder :
( ('a dcalc_lcalc, 'a dcalc_lcalc, 'm) base_gexpr, ( ('a dcalc_lcalc, 'a dcalc_lcalc, 'm) base_gexpr,
@ -79,6 +42,113 @@ let binder_vars_used_at_most_once
in in
not (Array.exists (fun c -> c > 1) (vars_count body)) not (Array.exists (fun c -> c > 1) (vars_count body))
(* beta reduction when variables not used, and for variable aliases and
literal *)
let simplified_apply f args tys =
match f with
| EAbs { binder; _ }, _
when binder_vars_used_at_most_once binder
|| List.for_all
(function (EVar _ | ELit _), _ -> true | _ -> false)
args ->
Mark.remove (Bindlib.msubst binder (List.map fst args |> Array.of_list))
| _ -> EApp { f; args; tys }
let literal_bool = function
| ELit (LBool b), _
| EAppOp { op = Log _, _; args = [(ELit (LBool b), _)]; _ }, _ ->
Some b
| _ -> None
let simplified_ifthenelse cond etrue efalse m =
if Expr.equal etrue efalse then Mark.remove etrue
else
match literal_bool etrue, literal_bool efalse with
| Some true, Some false -> Mark.remove cond
| Some false, Some true ->
EAppOp
{
op = Not, Expr.mark_pos m;
tys = [TLit TBool, Expr.mark_pos m];
args = [cond];
}
| Some true, Some true | Some false, Some false -> Mark.remove etrue
| _ -> (
match literal_bool cond with
| Some true -> Mark.remove etrue
| Some false -> Mark.remove efalse
| None -> EIfThenElse { cond; etrue; efalse })
(* builds a [EMatch] term, flattening nested matches/if-then-else: the matching
arg branching are explored, and if they all lead to enum constructor
literals, the surrounding match cases are inlined. Code duplication is
detected and aborts the inlining. *)
let simplified_match enum_name match_arg cases mark =
let max_duplicate_inlining_size = 3 in
let allow_duplicate_inlining_cases =
EnumConstructor.Map.fold
(fun cons f acc ->
if Expr.size f <= max_duplicate_inlining_size then
EnumConstructor.Set.add cons acc
else acc)
cases EnumConstructor.Set.empty
in
let app_cases cons e =
simplified_apply
(EnumConstructor.Map.find cons cases)
[e]
[Expr.maybe_ty (Mark.get e)]
in
let ret_ty = Expr.maybe_ty mark in
let rec aux seen_constrs = function
| EInj { cons; e; _ }, m ->
if EnumConstructor.Set.mem cons seen_constrs then raise Exit;
(* Abort inlining to avoid code duplication *)
let seen_constrs =
if EnumConstructor.Set.mem cons allow_duplicate_inlining_cases then
seen_constrs
else EnumConstructor.Set.add cons seen_constrs
in
seen_constrs, (app_cases cons e, Expr.with_ty m ret_ty)
| EMatch ({ cases; _ } as ematch), m ->
let seen_constrs, cases =
EnumConstructor.Map.fold
(fun cons case (seen_constrs, acc) ->
match case with
| EAbs ({ binder; _ } as eabs), m ->
let vars, body = Bindlib.unmbind binder in
let seen_constrs, body = aux seen_constrs body in
let binder = Bindlib.unbox (Expr.bind vars (Expr.rebox body)) in
let m =
Expr.map_ty
(function
| TArrow (args, _), pos -> TArrow (args, ret_ty), pos
| (TAny, _) as t -> t
| _ -> assert false)
m
in
( seen_constrs,
EnumConstructor.Map.add cons (EAbs { eabs with binder }, m) acc
)
| _ -> assert false)
cases
(seen_constrs, EnumConstructor.Map.empty)
in
seen_constrs, (EMatch { ematch with cases }, Expr.with_ty m ret_ty)
| EIfThenElse { cond; etrue; efalse }, m ->
let seen_constrs, etrue = aux seen_constrs etrue in
let seen_constrs, efalse = aux seen_constrs efalse in
let mark = Expr.with_ty m ret_ty in
seen_constrs, (simplified_ifthenelse cond etrue efalse mark, mark)
| _ -> raise Exit
in
try
let _seen_contrs, e = aux EnumConstructor.Set.empty match_arg in
Mark.remove e
with Exit ->
(* Optimisation was aborted due a non-terminal or code duplication *)
EMatch { e = match_arg; cases; name = enum_name }
let rec optimize_expr : let rec optimize_expr :
type a b. type a b.
(a, b, 'm) optimizations_ctx -> (a, b, 'm) optimizations_ctx ->
@ -108,59 +178,8 @@ let rec optimize_expr :
| EAppOp { op = And, _; args = [(e, _); (ELit (LBool b), _)]; _ } -> | EAppOp { op = And, _; args = [(e, _); (ELit (LBool b), _)]; _ } ->
(* reduction of logical and *) (* reduction of logical and *)
if b then e else ELit (LBool false) if b then e else ELit (LBool false)
| EMatch { e = EInj { e = e'; cons; name = n' }, _; cases; name = n } | EMatch { name; e; cases } -> simplified_match name e cases mark
(* iota-reduction *) | EApp { f; args; tys } -> simplified_apply f args tys
when EnumName.equal n n' -> (
(* match E x with | E y -> e1 = e1[y |-> x]*)
match Mark.remove @@ EnumConstructor.Map.find cons cases with
(* holds because of invariant_match_inversion *)
| EAbs { binder; _ } ->
Mark.remove
(Bindlib.msubst binder ([e'] |> List.map fst |> Array.of_list))
| _ -> assert false)
| EMatch { e = e'; cases; name = n } when all_match_cases_are_id_fun cases n
->
(* iota-reduction when the match is equivalent to an identity function *)
Mark.remove e'
| EMatch
{
e = EMatch { e = arg; cases = cases1; name = n1 }, _;
cases = cases2;
name = n2;
}
when EnumName.equal n1 n2
&& all_match_cases_map_to_same_constructor cases1 n1 ->
(* iota-reduction when the matched expression is itself a match of the
same enum mapping all constructors to themselves *)
let cases =
EnumConstructor.Map.merge
(fun _i o1 o2 ->
match o1, o2 with
| Some b1, Some e2 -> (
match Mark.remove b1, Mark.remove e2 with
| EAbs { binder = b1; _ }, EAbs { binder = b2; tys } -> (
let v1, e1 = Bindlib.unmbind b1 in
match Mark.remove e1 with
| EInj { e = e1, _; _ } ->
Some
(Expr.unbox
(Expr.make_abs v1
(Expr.rebox (Bindlib.msubst b2 [| e1 |]))
tys (Expr.pos e2)))
| _ -> assert false)
| _ -> assert false)
| _ -> assert false)
cases1 cases2
in
EMatch { e = arg; cases; name = n1 }
| EApp { f = EAbs { binder; _ }, _; args; _ }
when binder_vars_used_at_most_once binder
|| List.for_all
(function (EVar _ | ELit _), _ -> true | _ -> false)
args ->
(* beta reduction when variables not used, and for variable aliases and
literal *)
Mark.remove (Bindlib.msubst binder (List.map fst args |> Array.of_list))
| EStructAccess { name; field; e = EStruct { name = name1; fields }, _ } | EStructAccess { name; field; e = EStruct { name = name1; fields }, _ }
when StructName.equal name name1 -> when StructName.equal name name1 ->
Mark.remove (StructField.Map.find field fields) Mark.remove (StructField.Map.find field fields)
@ -193,12 +212,7 @@ let rec optimize_expr :
-> ->
(* No exceptions with condition [true] *) (* No exceptions with condition [true] *)
Mark.remove cons Mark.remove cons
| ( [], | [], cond -> simplified_ifthenelse cond cons (EEmpty, mark) mark
( ( ELit (LBool false)
| EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ),
_ ) ) ->
(* No exceptions and condition false *)
EEmpty
| ( [except], | ( [except],
( ( ELit (LBool false) ( ( ELit (LBool false)
| EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ), | EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ),
@ -206,49 +220,8 @@ let rec optimize_expr :
(* Single exception and condition false *) (* Single exception and condition false *)
Mark.remove except Mark.remove except
| excepts, just -> EDefault { excepts; just; cons }) | excepts, just -> EDefault { excepts; just; cons })
| EIfThenElse | EIfThenElse { cond; etrue; efalse } ->
{ simplified_ifthenelse cond etrue efalse mark
cond =
( ELit (LBool true), _
| EAppOp { op = Log _, _; args = [(ELit (LBool true), _)]; _ }, _ );
etrue;
_;
} ->
Mark.remove etrue
| EIfThenElse
{
cond =
( ( ELit (LBool false)
| EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ),
_ );
efalse;
_;
} ->
Mark.remove efalse
| EIfThenElse
{
cond;
etrue =
( ( ELit (LBool btrue)
| EAppOp { op = Log _, _; args = [(ELit (LBool btrue), _)]; _ } ),
_ );
efalse =
( ( ELit (LBool bfalse)
| EAppOp { op = Log _, _; args = [(ELit (LBool bfalse), _)]; _ }
),
_ );
} ->
if btrue && not bfalse then Mark.remove cond
else if (not btrue) && bfalse then
EAppOp
{
op = Not, Expr.mark_pos mark;
tys = [TLit TBool, Expr.mark_pos mark];
args = [cond];
}
(* note: this last call eliminates the condition & might skip log calls
as well *)
else (* btrue = bfalse *) ELit (LBool btrue)
| EAppOp { op = Op.Fold, _; args = [_f; init; (EArray [], _)]; _ } -> | EAppOp { op = Op.Fold, _; args = [_f; init; (EArray [], _)]; _ } ->
(*reduces a fold with an empty list *) (*reduces a fold with an empty list *)
Mark.remove init Mark.remove init

View File

@ -131,13 +131,9 @@ let scope Foo
Foo_in.b_in Foo_in.b_in
in in
let set b : bool = let set b : bool =
match match (handle_exceptions [b.0 b.1 ()]) with
(match (handle_exceptions [b.0 b.1 ()]) with | ENone → true
| ENone → ESome true | ESome x → x
| ESome x → ESome x)
with
| ENone → error NoValue
| ESome arg → arg
in in
let set r : let set r :
Result { Result {