Add some optimisations to nested maps

This commit is contained in:
Louis Gesbert 2024-01-26 16:02:36 +01:00
parent 886a2cf914
commit 7b43f393c5
3 changed files with 220 additions and 0 deletions

View File

@ -250,6 +250,101 @@ let rec optimize_expr :
| EAppOp { op = Op.Fold; args = [_f; init; (EArray [], _)]; _ } ->
(*reduces a fold with an empty list *)
Mark.remove init
| EAppOp
{
op = Map;
args =
[
f1;
( EAppOp
{
op = Map;
args = [f2; ls];
tys = [_; ((TArray xty, _) as lsty)];
},
m2 );
];
tys = [_; (TArray yty, _)];
} ->
(* map f (map g l) => map (f o g) l *)
let fg =
let v =
Var.make
(match f2 with
| EAbs { binder; _ }, _ -> (Bindlib.mbinder_names binder).(0)
| _ -> "x")
in
let mty m =
Expr.map_ty (function TArray ty, _ -> ty | _, pos -> TAny, pos) m
in
let x = Expr.evar v (mty (Mark.get ls)) in
Expr.make_abs [| v |]
(Expr.eapp ~f:(Expr.box f1)
~args:[Expr.eapp ~f:(Expr.box f2) ~args:[x] ~tys:[xty] (mty m2)]
~tys:[yty] (mty mark))
[xty] (Expr.pos e)
in
let fg = optimize_expr ctx (Expr.unbox fg) in
let mapl =
Expr.eappop ~op:Map
~args:[fg; Expr.box ls]
~tys:[Expr.maybe_ty (Mark.get fg); lsty]
mark
in
Mark.remove (Expr.unbox mapl)
| EAppOp
{
op = Map;
args =
[
f1;
( EAppOp
{
op = Map2;
args = [f2; ls1; ls2];
tys =
[
_;
((TArray x1ty, _) as ls1ty);
((TArray x2ty, _) as ls2ty);
];
},
m2 );
];
tys = [_; (TArray yty, _)];
} ->
(* map f (map2 g l1 l2) => map2 (f o g) l1 l2 *)
let fg =
let v1, v2 =
match f2 with
| EAbs { binder; _ }, _ ->
let names = Bindlib.mbinder_names binder in
Var.make names.(0), Var.make names.(1)
| _ -> Var.make "x", Var.make "y"
in
let mty m =
Expr.map_ty (function TArray ty, _ -> ty | _, pos -> TAny, pos) m
in
let x1 = Expr.evar v1 (mty (Mark.get ls1)) in
let x2 = Expr.evar v2 (mty (Mark.get ls2)) in
Expr.make_abs [| v1; v2 |]
(Expr.eapp ~f:(Expr.box f1)
~args:
[
Expr.eapp ~f:(Expr.box f2) ~args:[x1; x2] ~tys:[x1ty; x2ty]
(mty m2);
]
~tys:[yty] (mty mark))
[x1ty; x2ty] (Expr.pos e)
in
let fg = optimize_expr ctx (Expr.unbox fg) in
let mapl =
Expr.eappop ~op:Map2
~args:[fg; Expr.box ls1; Expr.box ls2]
~tys:[Expr.maybe_ty (Mark.get fg); ls1ty; ls2ty]
mark
in
Mark.remove (Expr.unbox mapl)
| EAppOp
{
op = Op.Fold;

View File

@ -353,7 +353,10 @@ module Oper : sig
val o_xor : bool -> bool -> bool
val o_eq : 'a -> 'a -> bool
val o_map : ('a -> 'b) -> 'a array -> 'b array
val o_map2 : ('a -> 'b -> 'c) -> 'a array -> 'b array -> 'c array
(** @raise [NotSameLength] *)
val o_reduce : ('a -> 'a -> 'a) -> 'a -> 'a array -> 'a
val o_concat : 'a array -> 'a array -> 'a array
val o_filter : ('a -> bool) -> 'a array -> 'a array

View File

@ -93,3 +93,125 @@ r6 =
($170.00, 0.833,333,333,333,333,333,33…)
]
```
```catala-test-inline
$ catala interpret -s S -O
[RESULT] Computation successful! Results:
[RESULT]
r1 =
[
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
($170.00, 0.833,333,333,333,333,333,33…)
]
[RESULT]
r2 =
[
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
($170.00, 0.833,333,333,333,333,333,33…)
]
[RESULT]
r3 =
[
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
($170.00, 0.833,333,333,333,333,333,33…)
]
[RESULT]
r4 =
[
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
($170.00, 0.833,333,333,333,333,333,33…)
]
[RESULT]
r5 =
[
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
($170.00, 0.833,333,333,333,333,333,33…)
]
[RESULT]
r6 =
[
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
($170.00, 0.833,333,333,333,333,333,33…)
]
```
```catala-test-inline
$ catala dcalc -O
let lis1 : list of decimal = [12.; 13.; 14.; 15.; 16.; 17.] in
let lis2 : list of money =
[¤10.00; ¤1.00; ¤100.00; ¤42.00; ¤17.00; ¤10.00]
in
let lis3 : list of money =
[¤20.00; ¤200.00; ¤10.00; ¤23.00; ¤25.00; ¤12.00]
in
let grok : (decimal, money, money) → (money * decimal) =
λ (dec: decimal) (mon1: money) (mon2: money) →
(mon1 * dec, mon1 / mon2)
in
let tlist : list of (decimal * money * money) =
map2
(λ (x1: decimal) (x2: (money * money)) →
let a_b_c : (decimal * money * money) = (x1, x2.0, x2.1) in
(a_b_c.0, a_b_c.1, a_b_c.2))
lis1
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
in
let S : S_in → S =
λ (S_in: S_in) →
let r1 : list of (money * decimal) =
map (λ (x: (decimal * money * money)) → grok x.0 x.1 x.2) tlist
in
let r2 : list of (money * decimal) =
map2
(λ (x1: decimal) (x2: (money * money)) →
let x3 : (decimal * money * money) = (x1, x2.0, x2.1) in
grok x3.0 x3.1 x3.2)
lis1
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
in
let r3 : list of (money * decimal) =
map2
(λ (x1: decimal) (x2: (money * money)) →
let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in
grok x_y_z.0 x_y_z.1 x_y_z.2)
lis1
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
in
let r4 : list of (money * decimal) =
map (λ (x_y_z: (decimal * money * money)) →
(x_y_z.1 * x_y_z.0, x_y_z.1 / x_y_z.2))
tlist
in
let r5 : list of (money * decimal) =
map2
(λ (x1: decimal) (x2: (money * money)) →
let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in
(x_y_z.1 * x_y_z.0, x_y_z.1 / x_y_z.2))
lis1
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
in
let r6 : list of (money * decimal) =
map2
(λ (x1: (decimal * money)) (x2: money) →
let xy_z : ((decimal * money) * money) = (x1, x2) in
let xy : (decimal * money) = xy_z.0 in
let z : money = xy_z.1 in
(xy.1 * xy.0, xy.1 / z))
map2
(λ (x1: decimal) (x2: money) →
let x_y : (decimal * money) = (x1, x2) in
(x_y.0, x_y.1))
lis1
lis2
lis3
in
{ S r1 = r1; r2 = r2; r3 = r3; r4 = r4; r5 = r5; r6 = r6; }
in
S
```