From 3cbfa5f258930e5af943d48f6b0ffe4b31a7a898 Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Tue, 27 Aug 2024 16:56:38 +0200 Subject: [PATCH] Generalise optimisation of nested matches --- compiler/shared_ast/optimizations.ml | 251 ++++++++---------- .../scope_call_func_struct_closure.catala_en | 10 +- 2 files changed, 115 insertions(+), 146 deletions(-) diff --git a/compiler/shared_ast/optimizations.ml b/compiler/shared_ast/optimizations.ml index 3ad6f126..64774556 100644 --- a/compiler/shared_ast/optimizations.ml +++ b/compiler/shared_ast/optimizations.ml @@ -19,43 +19,6 @@ open Definitions type ('a, 'b, 'm) optimizations_ctx = { decl_ctx : decl_ctx } -let all_match_cases_are_id_fun cases n = - EnumConstructor.Map.for_all - (fun i case -> - match Mark.remove case with - | EAbs { binder; _ } -> ( - let var, body = Bindlib.unmbind binder in - (* because of invariant [invariant_match], the arity is always one. *) - let[@warning "-8"] [| var |] = var in - match Mark.remove body with - | EInj { cons = i'; name = n'; e = EVar x, _ } -> - EnumConstructor.equal i i' - && EnumName.equal n n' - && Bindlib.eq_vars x var - | EInj { cons = i'; name = n'; e = ELit LUnit, _ } -> - (* since unit is the only value of type unit. We don't need to check - the equality. *) - EnumConstructor.equal i i' && EnumName.equal n n' - | _ -> false) - | _ -> - (* because of invariant [invariant_match], there is always some EAbs in - each cases. *) - assert false) - cases - -let all_match_cases_map_to_same_constructor cases n = - EnumConstructor.Map.for_all - (fun i case -> - match Mark.remove case with - | EAbs { binder; _ } -> ( - let _, body = Bindlib.unmbind binder in - match Mark.remove body with - | EInj { cons = i'; name = n'; _ } -> - EnumConstructor.equal i i' && EnumName.equal n n' - | _ -> false) - | _ -> assert false) - cases - let binder_vars_used_at_most_once (binder : ( ('a dcalc_lcalc, 'a dcalc_lcalc, 'm) base_gexpr, @@ -79,6 +42,113 @@ let binder_vars_used_at_most_once in not (Array.exists (fun c -> c > 1) (vars_count body)) +(* beta reduction when variables not used, and for variable aliases and + literal *) +let simplified_apply f args tys = + match f with + | EAbs { binder; _ }, _ + when binder_vars_used_at_most_once binder + || List.for_all + (function (EVar _ | ELit _), _ -> true | _ -> false) + args -> + Mark.remove (Bindlib.msubst binder (List.map fst args |> Array.of_list)) + | _ -> EApp { f; args; tys } + +let literal_bool = function + | ELit (LBool b), _ + | EAppOp { op = Log _, _; args = [(ELit (LBool b), _)]; _ }, _ -> + Some b + | _ -> None + +let simplified_ifthenelse cond etrue efalse m = + if Expr.equal etrue efalse then Mark.remove etrue + else + match literal_bool etrue, literal_bool efalse with + | Some true, Some false -> Mark.remove cond + | Some false, Some true -> + EAppOp + { + op = Not, Expr.mark_pos m; + tys = [TLit TBool, Expr.mark_pos m]; + args = [cond]; + } + | Some true, Some true | Some false, Some false -> Mark.remove etrue + | _ -> ( + match literal_bool cond with + | Some true -> Mark.remove etrue + | Some false -> Mark.remove efalse + | None -> EIfThenElse { cond; etrue; efalse }) + +(* builds a [EMatch] term, flattening nested matches/if-then-else: the matching + arg branching are explored, and if they all lead to enum constructor + literals, the surrounding match cases are inlined. Code duplication is + detected and aborts the inlining. *) +let simplified_match enum_name match_arg cases mark = + let max_duplicate_inlining_size = 3 in + let allow_duplicate_inlining_cases = + EnumConstructor.Map.fold + (fun cons f acc -> + if Expr.size f <= max_duplicate_inlining_size then + EnumConstructor.Set.add cons acc + else acc) + cases EnumConstructor.Set.empty + in + let app_cases cons e = + simplified_apply + (EnumConstructor.Map.find cons cases) + [e] + [Expr.maybe_ty (Mark.get e)] + in + let ret_ty = Expr.maybe_ty mark in + let rec aux seen_constrs = function + | EInj { cons; e; _ }, m -> + if EnumConstructor.Set.mem cons seen_constrs then raise Exit; + (* Abort inlining to avoid code duplication *) + let seen_constrs = + if EnumConstructor.Set.mem cons allow_duplicate_inlining_cases then + seen_constrs + else EnumConstructor.Set.add cons seen_constrs + in + seen_constrs, (app_cases cons e, Expr.with_ty m ret_ty) + | EMatch ({ cases; _ } as ematch), m -> + let seen_constrs, cases = + EnumConstructor.Map.fold + (fun cons case (seen_constrs, acc) -> + match case with + | EAbs ({ binder; _ } as eabs), m -> + let vars, body = Bindlib.unmbind binder in + let seen_constrs, body = aux seen_constrs body in + let binder = Bindlib.unbox (Expr.bind vars (Expr.rebox body)) in + let m = + Expr.map_ty + (function + | TArrow (args, _), pos -> TArrow (args, ret_ty), pos + | (TAny, _) as t -> t + | _ -> assert false) + m + in + ( seen_constrs, + EnumConstructor.Map.add cons (EAbs { eabs with binder }, m) acc + ) + | _ -> assert false) + cases + (seen_constrs, EnumConstructor.Map.empty) + in + seen_constrs, (EMatch { ematch with cases }, Expr.with_ty m ret_ty) + | EIfThenElse { cond; etrue; efalse }, m -> + let seen_constrs, etrue = aux seen_constrs etrue in + let seen_constrs, efalse = aux seen_constrs efalse in + let mark = Expr.with_ty m ret_ty in + seen_constrs, (simplified_ifthenelse cond etrue efalse mark, mark) + | _ -> raise Exit + in + try + let _seen_contrs, e = aux EnumConstructor.Set.empty match_arg in + Mark.remove e + with Exit -> + (* Optimisation was aborted due a non-terminal or code duplication *) + EMatch { e = match_arg; cases; name = enum_name } + let rec optimize_expr : type a b. (a, b, 'm) optimizations_ctx -> @@ -108,59 +178,8 @@ let rec optimize_expr : | EAppOp { op = And, _; args = [(e, _); (ELit (LBool b), _)]; _ } -> (* reduction of logical and *) if b then e else ELit (LBool false) - | EMatch { e = EInj { e = e'; cons; name = n' }, _; cases; name = n } - (* iota-reduction *) - when EnumName.equal n n' -> ( - (* match E x with | E y -> e1 = e1[y |-> x]*) - match Mark.remove @@ EnumConstructor.Map.find cons cases with - (* holds because of invariant_match_inversion *) - | EAbs { binder; _ } -> - Mark.remove - (Bindlib.msubst binder ([e'] |> List.map fst |> Array.of_list)) - | _ -> assert false) - | EMatch { e = e'; cases; name = n } when all_match_cases_are_id_fun cases n - -> - (* iota-reduction when the match is equivalent to an identity function *) - Mark.remove e' - | EMatch - { - e = EMatch { e = arg; cases = cases1; name = n1 }, _; - cases = cases2; - name = n2; - } - when EnumName.equal n1 n2 - && all_match_cases_map_to_same_constructor cases1 n1 -> - (* iota-reduction when the matched expression is itself a match of the - same enum mapping all constructors to themselves *) - let cases = - EnumConstructor.Map.merge - (fun _i o1 o2 -> - match o1, o2 with - | Some b1, Some e2 -> ( - match Mark.remove b1, Mark.remove e2 with - | EAbs { binder = b1; _ }, EAbs { binder = b2; tys } -> ( - let v1, e1 = Bindlib.unmbind b1 in - match Mark.remove e1 with - | EInj { e = e1, _; _ } -> - Some - (Expr.unbox - (Expr.make_abs v1 - (Expr.rebox (Bindlib.msubst b2 [| e1 |])) - tys (Expr.pos e2))) - | _ -> assert false) - | _ -> assert false) - | _ -> assert false) - cases1 cases2 - in - EMatch { e = arg; cases; name = n1 } - | EApp { f = EAbs { binder; _ }, _; args; _ } - when binder_vars_used_at_most_once binder - || List.for_all - (function (EVar _ | ELit _), _ -> true | _ -> false) - args -> - (* beta reduction when variables not used, and for variable aliases and - literal *) - Mark.remove (Bindlib.msubst binder (List.map fst args |> Array.of_list)) + | EMatch { name; e; cases } -> simplified_match name e cases mark + | EApp { f; args; tys } -> simplified_apply f args tys | EStructAccess { name; field; e = EStruct { name = name1; fields }, _ } when StructName.equal name name1 -> Mark.remove (StructField.Map.find field fields) @@ -193,12 +212,7 @@ let rec optimize_expr : -> (* No exceptions with condition [true] *) Mark.remove cons - | ( [], - ( ( ELit (LBool false) - | EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ), - _ ) ) -> - (* No exceptions and condition false *) - EEmpty + | [], cond -> simplified_ifthenelse cond cons (EEmpty, mark) mark | ( [except], ( ( ELit (LBool false) | EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ), @@ -206,49 +220,8 @@ let rec optimize_expr : (* Single exception and condition false *) Mark.remove except | excepts, just -> EDefault { excepts; just; cons }) - | EIfThenElse - { - cond = - ( ELit (LBool true), _ - | EAppOp { op = Log _, _; args = [(ELit (LBool true), _)]; _ }, _ ); - etrue; - _; - } -> - Mark.remove etrue - | EIfThenElse - { - cond = - ( ( ELit (LBool false) - | EAppOp { op = Log _, _; args = [(ELit (LBool false), _)]; _ } ), - _ ); - efalse; - _; - } -> - Mark.remove efalse - | EIfThenElse - { - cond; - etrue = - ( ( ELit (LBool btrue) - | EAppOp { op = Log _, _; args = [(ELit (LBool btrue), _)]; _ } ), - _ ); - efalse = - ( ( ELit (LBool bfalse) - | EAppOp { op = Log _, _; args = [(ELit (LBool bfalse), _)]; _ } - ), - _ ); - } -> - if btrue && not bfalse then Mark.remove cond - else if (not btrue) && bfalse then - EAppOp - { - op = Not, Expr.mark_pos mark; - tys = [TLit TBool, Expr.mark_pos mark]; - args = [cond]; - } - (* note: this last call eliminates the condition & might skip log calls - as well *) - else (* btrue = bfalse *) ELit (LBool btrue) + | EIfThenElse { cond; etrue; efalse } -> + simplified_ifthenelse cond etrue efalse mark | EAppOp { op = Op.Fold, _; args = [_f; init; (EArray [], _)]; _ } -> (*reduces a fold with an empty list *) Mark.remove init diff --git a/tests/func/good/scope_call_func_struct_closure.catala_en b/tests/func/good/scope_call_func_struct_closure.catala_en index 640c66df..4e9f7b48 100644 --- a/tests/func/good/scope_call_func_struct_closure.catala_en +++ b/tests/func/good/scope_call_func_struct_closure.catala_en @@ -131,13 +131,9 @@ let scope Foo Foo_in.b_in in let set b : bool = - match - (match (handle_exceptions [b.0 b.1 ()]) with - | ENone → ESome true - | ESome x → ESome x) - with - | ENone → error NoValue - | ESome arg → arg + match (handle_exceptions [b.0 b.1 ()]) with + | ENone → true + | ESome x → x in let set r : Result {