From 7b43f393c5e5990d061f6f749043b6980d0b445f Mon Sep 17 00:00:00 2001 From: Louis Gesbert Date: Fri, 26 Jan 2024 16:02:36 +0100 Subject: [PATCH] Add some optimisations to nested maps --- compiler/shared_ast/optimizations.ml | 95 +++++++++++++++++ runtimes/ocaml/runtime.mli | 3 + tests/test_tuples/good/tuplists.catala_en | 122 ++++++++++++++++++++++ 3 files changed, 220 insertions(+) diff --git a/compiler/shared_ast/optimizations.ml b/compiler/shared_ast/optimizations.ml index b9d1401f..8aca6d2b 100644 --- a/compiler/shared_ast/optimizations.ml +++ b/compiler/shared_ast/optimizations.ml @@ -250,6 +250,101 @@ let rec optimize_expr : | EAppOp { op = Op.Fold; args = [_f; init; (EArray [], _)]; _ } -> (*reduces a fold with an empty list *) Mark.remove init + | EAppOp + { + op = Map; + args = + [ + f1; + ( EAppOp + { + op = Map; + args = [f2; ls]; + tys = [_; ((TArray xty, _) as lsty)]; + }, + m2 ); + ]; + tys = [_; (TArray yty, _)]; + } -> + (* map f (map g l) => map (f o g) l *) + let fg = + let v = + Var.make + (match f2 with + | EAbs { binder; _ }, _ -> (Bindlib.mbinder_names binder).(0) + | _ -> "x") + in + let mty m = + Expr.map_ty (function TArray ty, _ -> ty | _, pos -> TAny, pos) m + in + let x = Expr.evar v (mty (Mark.get ls)) in + Expr.make_abs [| v |] + (Expr.eapp ~f:(Expr.box f1) + ~args:[Expr.eapp ~f:(Expr.box f2) ~args:[x] ~tys:[xty] (mty m2)] + ~tys:[yty] (mty mark)) + [xty] (Expr.pos e) + in + let fg = optimize_expr ctx (Expr.unbox fg) in + let mapl = + Expr.eappop ~op:Map + ~args:[fg; Expr.box ls] + ~tys:[Expr.maybe_ty (Mark.get fg); lsty] + mark + in + Mark.remove (Expr.unbox mapl) + | EAppOp + { + op = Map; + args = + [ + f1; + ( EAppOp + { + op = Map2; + args = [f2; ls1; ls2]; + tys = + [ + _; + ((TArray x1ty, _) as ls1ty); + ((TArray x2ty, _) as ls2ty); + ]; + }, + m2 ); + ]; + tys = [_; (TArray yty, _)]; + } -> + (* map f (map2 g l1 l2) => map2 (f o g) l1 l2 *) + let fg = + let v1, v2 = + match f2 with + | EAbs { binder; _ }, _ -> + let names = Bindlib.mbinder_names binder in + Var.make names.(0), Var.make names.(1) + | _ -> Var.make "x", Var.make "y" + in + let mty m = + Expr.map_ty (function TArray ty, _ -> ty | _, pos -> TAny, pos) m + in + let x1 = Expr.evar v1 (mty (Mark.get ls1)) in + let x2 = Expr.evar v2 (mty (Mark.get ls2)) in + Expr.make_abs [| v1; v2 |] + (Expr.eapp ~f:(Expr.box f1) + ~args: + [ + Expr.eapp ~f:(Expr.box f2) ~args:[x1; x2] ~tys:[x1ty; x2ty] + (mty m2); + ] + ~tys:[yty] (mty mark)) + [x1ty; x2ty] (Expr.pos e) + in + let fg = optimize_expr ctx (Expr.unbox fg) in + let mapl = + Expr.eappop ~op:Map2 + ~args:[fg; Expr.box ls1; Expr.box ls2] + ~tys:[Expr.maybe_ty (Mark.get fg); ls1ty; ls2ty] + mark + in + Mark.remove (Expr.unbox mapl) | EAppOp { op = Op.Fold; diff --git a/runtimes/ocaml/runtime.mli b/runtimes/ocaml/runtime.mli index d6c22a36..b5c4fa9f 100644 --- a/runtimes/ocaml/runtime.mli +++ b/runtimes/ocaml/runtime.mli @@ -353,7 +353,10 @@ module Oper : sig val o_xor : bool -> bool -> bool val o_eq : 'a -> 'a -> bool val o_map : ('a -> 'b) -> 'a array -> 'b array + val o_map2 : ('a -> 'b -> 'c) -> 'a array -> 'b array -> 'c array + (** @raise [NotSameLength] *) + val o_reduce : ('a -> 'a -> 'a) -> 'a -> 'a array -> 'a val o_concat : 'a array -> 'a array -> 'a array val o_filter : ('a -> bool) -> 'a array -> 'a array diff --git a/tests/test_tuples/good/tuplists.catala_en b/tests/test_tuples/good/tuplists.catala_en index 36d53010..bcb89da1 100644 --- a/tests/test_tuples/good/tuplists.catala_en +++ b/tests/test_tuples/good/tuplists.catala_en @@ -93,3 +93,125 @@ r6 = ($170.00, 0.833,333,333,333,333,333,33…) ] ``` + +```catala-test-inline +$ catala interpret -s S -O +[RESULT] Computation successful! Results: +[RESULT] +r1 = + [ + ($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0); + ($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68); + ($170.00, 0.833,333,333,333,333,333,33…) + ] +[RESULT] +r2 = + [ + ($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0); + ($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68); + ($170.00, 0.833,333,333,333,333,333,33…) + ] +[RESULT] +r3 = + [ + ($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0); + ($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68); + ($170.00, 0.833,333,333,333,333,333,33…) + ] +[RESULT] +r4 = + [ + ($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0); + ($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68); + ($170.00, 0.833,333,333,333,333,333,33…) + ] +[RESULT] +r5 = + [ + ($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0); + ($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68); + ($170.00, 0.833,333,333,333,333,333,33…) + ] +[RESULT] +r6 = + [ + ($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0); + ($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68); + ($170.00, 0.833,333,333,333,333,333,33…) + ] +``` + +```catala-test-inline +$ catala dcalc -O +let lis1 : list of decimal = [12.; 13.; 14.; 15.; 16.; 17.] in +let lis2 : list of money = + [¤10.00; ¤1.00; ¤100.00; ¤42.00; ¤17.00; ¤10.00] +in +let lis3 : list of money = + [¤20.00; ¤200.00; ¤10.00; ¤23.00; ¤25.00; ¤12.00] +in +let grok : (decimal, money, money) → (money * decimal) = + λ (dec: decimal) (mon1: money) (mon2: money) → + (mon1 * dec, mon1 / mon2) +in +let tlist : list of (decimal * money * money) = + map2 + (λ (x1: decimal) (x2: (money * money)) → + let a_b_c : (decimal * money * money) = (x1, x2.0, x2.1) in + (a_b_c.0, a_b_c.1, a_b_c.2)) + lis1 + map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 +in +let S : S_in → S = + λ (S_in: S_in) → + let r1 : list of (money * decimal) = + map (λ (x: (decimal * money * money)) → grok x.0 x.1 x.2) tlist + in + let r2 : list of (money * decimal) = + map2 + (λ (x1: decimal) (x2: (money * money)) → + let x3 : (decimal * money * money) = (x1, x2.0, x2.1) in + grok x3.0 x3.1 x3.2) + lis1 + map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 + in + let r3 : list of (money * decimal) = + map2 + (λ (x1: decimal) (x2: (money * money)) → + let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in + grok x_y_z.0 x_y_z.1 x_y_z.2) + lis1 + map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 + in + let r4 : list of (money * decimal) = + map (λ (x_y_z: (decimal * money * money)) → + (x_y_z.1 * x_y_z.0, x_y_z.1 / x_y_z.2)) + tlist + in + let r5 : list of (money * decimal) = + map2 + (λ (x1: decimal) (x2: (money * money)) → + let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in + (x_y_z.1 * x_y_z.0, x_y_z.1 / x_y_z.2)) + lis1 + map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3 + in + let r6 : list of (money * decimal) = + map2 + (λ (x1: (decimal * money)) (x2: money) → + let xy_z : ((decimal * money) * money) = (x1, x2) in + let xy : (decimal * money) = xy_z.0 in + let z : money = xy_z.1 in + (xy.1 * xy.0, xy.1 / z)) + map2 + (λ (x1: decimal) (x2: money) → + let x_y : (decimal * money) = (x1, x2) in + (x_y.0, x_y.1)) + lis1 + lis2 + lis3 + in + { S r1 = r1; r2 = r2; r3 = r3; r4 = r4; r5 = r5; r6 = r6; } +in +S +```