mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Add some optimisations to nested maps
This commit is contained in:
parent
886a2cf914
commit
7b43f393c5
@ -250,6 +250,101 @@ let rec optimize_expr :
|
||||
| EAppOp { op = Op.Fold; args = [_f; init; (EArray [], _)]; _ } ->
|
||||
(*reduces a fold with an empty list *)
|
||||
Mark.remove init
|
||||
| EAppOp
|
||||
{
|
||||
op = Map;
|
||||
args =
|
||||
[
|
||||
f1;
|
||||
( EAppOp
|
||||
{
|
||||
op = Map;
|
||||
args = [f2; ls];
|
||||
tys = [_; ((TArray xty, _) as lsty)];
|
||||
},
|
||||
m2 );
|
||||
];
|
||||
tys = [_; (TArray yty, _)];
|
||||
} ->
|
||||
(* map f (map g l) => map (f o g) l *)
|
||||
let fg =
|
||||
let v =
|
||||
Var.make
|
||||
(match f2 with
|
||||
| EAbs { binder; _ }, _ -> (Bindlib.mbinder_names binder).(0)
|
||||
| _ -> "x")
|
||||
in
|
||||
let mty m =
|
||||
Expr.map_ty (function TArray ty, _ -> ty | _, pos -> TAny, pos) m
|
||||
in
|
||||
let x = Expr.evar v (mty (Mark.get ls)) in
|
||||
Expr.make_abs [| v |]
|
||||
(Expr.eapp ~f:(Expr.box f1)
|
||||
~args:[Expr.eapp ~f:(Expr.box f2) ~args:[x] ~tys:[xty] (mty m2)]
|
||||
~tys:[yty] (mty mark))
|
||||
[xty] (Expr.pos e)
|
||||
in
|
||||
let fg = optimize_expr ctx (Expr.unbox fg) in
|
||||
let mapl =
|
||||
Expr.eappop ~op:Map
|
||||
~args:[fg; Expr.box ls]
|
||||
~tys:[Expr.maybe_ty (Mark.get fg); lsty]
|
||||
mark
|
||||
in
|
||||
Mark.remove (Expr.unbox mapl)
|
||||
| EAppOp
|
||||
{
|
||||
op = Map;
|
||||
args =
|
||||
[
|
||||
f1;
|
||||
( EAppOp
|
||||
{
|
||||
op = Map2;
|
||||
args = [f2; ls1; ls2];
|
||||
tys =
|
||||
[
|
||||
_;
|
||||
((TArray x1ty, _) as ls1ty);
|
||||
((TArray x2ty, _) as ls2ty);
|
||||
];
|
||||
},
|
||||
m2 );
|
||||
];
|
||||
tys = [_; (TArray yty, _)];
|
||||
} ->
|
||||
(* map f (map2 g l1 l2) => map2 (f o g) l1 l2 *)
|
||||
let fg =
|
||||
let v1, v2 =
|
||||
match f2 with
|
||||
| EAbs { binder; _ }, _ ->
|
||||
let names = Bindlib.mbinder_names binder in
|
||||
Var.make names.(0), Var.make names.(1)
|
||||
| _ -> Var.make "x", Var.make "y"
|
||||
in
|
||||
let mty m =
|
||||
Expr.map_ty (function TArray ty, _ -> ty | _, pos -> TAny, pos) m
|
||||
in
|
||||
let x1 = Expr.evar v1 (mty (Mark.get ls1)) in
|
||||
let x2 = Expr.evar v2 (mty (Mark.get ls2)) in
|
||||
Expr.make_abs [| v1; v2 |]
|
||||
(Expr.eapp ~f:(Expr.box f1)
|
||||
~args:
|
||||
[
|
||||
Expr.eapp ~f:(Expr.box f2) ~args:[x1; x2] ~tys:[x1ty; x2ty]
|
||||
(mty m2);
|
||||
]
|
||||
~tys:[yty] (mty mark))
|
||||
[x1ty; x2ty] (Expr.pos e)
|
||||
in
|
||||
let fg = optimize_expr ctx (Expr.unbox fg) in
|
||||
let mapl =
|
||||
Expr.eappop ~op:Map2
|
||||
~args:[fg; Expr.box ls1; Expr.box ls2]
|
||||
~tys:[Expr.maybe_ty (Mark.get fg); ls1ty; ls2ty]
|
||||
mark
|
||||
in
|
||||
Mark.remove (Expr.unbox mapl)
|
||||
| EAppOp
|
||||
{
|
||||
op = Op.Fold;
|
||||
|
@ -353,7 +353,10 @@ module Oper : sig
|
||||
val o_xor : bool -> bool -> bool
|
||||
val o_eq : 'a -> 'a -> bool
|
||||
val o_map : ('a -> 'b) -> 'a array -> 'b array
|
||||
|
||||
val o_map2 : ('a -> 'b -> 'c) -> 'a array -> 'b array -> 'c array
|
||||
(** @raise [NotSameLength] *)
|
||||
|
||||
val o_reduce : ('a -> 'a -> 'a) -> 'a -> 'a array -> 'a
|
||||
val o_concat : 'a array -> 'a array -> 'a array
|
||||
val o_filter : ('a -> bool) -> 'a array -> 'a array
|
||||
|
@ -93,3 +93,125 @@ r6 =
|
||||
($170.00, 0.833,333,333,333,333,333,33…)
|
||||
]
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala interpret -s S -O
|
||||
[RESULT] Computation successful! Results:
|
||||
[RESULT]
|
||||
r1 =
|
||||
[
|
||||
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
|
||||
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
|
||||
($170.00, 0.833,333,333,333,333,333,33…)
|
||||
]
|
||||
[RESULT]
|
||||
r2 =
|
||||
[
|
||||
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
|
||||
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
|
||||
($170.00, 0.833,333,333,333,333,333,33…)
|
||||
]
|
||||
[RESULT]
|
||||
r3 =
|
||||
[
|
||||
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
|
||||
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
|
||||
($170.00, 0.833,333,333,333,333,333,33…)
|
||||
]
|
||||
[RESULT]
|
||||
r4 =
|
||||
[
|
||||
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
|
||||
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
|
||||
($170.00, 0.833,333,333,333,333,333,33…)
|
||||
]
|
||||
[RESULT]
|
||||
r5 =
|
||||
[
|
||||
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
|
||||
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
|
||||
($170.00, 0.833,333,333,333,333,333,33…)
|
||||
]
|
||||
[RESULT]
|
||||
r6 =
|
||||
[
|
||||
($120.00, 0.5); ($13.00, 0.005); ($1,400.00, 10.0);
|
||||
($630.00, 1.826,086,956,521,739,130,4…); ($272.00, 0.68);
|
||||
($170.00, 0.833,333,333,333,333,333,33…)
|
||||
]
|
||||
```
|
||||
|
||||
```catala-test-inline
|
||||
$ catala dcalc -O
|
||||
let lis1 : list of decimal = [12.; 13.; 14.; 15.; 16.; 17.] in
|
||||
let lis2 : list of money =
|
||||
[¤10.00; ¤1.00; ¤100.00; ¤42.00; ¤17.00; ¤10.00]
|
||||
in
|
||||
let lis3 : list of money =
|
||||
[¤20.00; ¤200.00; ¤10.00; ¤23.00; ¤25.00; ¤12.00]
|
||||
in
|
||||
let grok : (decimal, money, money) → (money * decimal) =
|
||||
λ (dec: decimal) (mon1: money) (mon2: money) →
|
||||
(mon1 * dec, mon1 / mon2)
|
||||
in
|
||||
let tlist : list of (decimal * money * money) =
|
||||
map2
|
||||
(λ (x1: decimal) (x2: (money * money)) →
|
||||
let a_b_c : (decimal * money * money) = (x1, x2.0, x2.1) in
|
||||
(a_b_c.0, a_b_c.1, a_b_c.2))
|
||||
lis1
|
||||
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
|
||||
in
|
||||
let S : S_in → S =
|
||||
λ (S_in: S_in) →
|
||||
let r1 : list of (money * decimal) =
|
||||
map (λ (x: (decimal * money * money)) → grok x.0 x.1 x.2) tlist
|
||||
in
|
||||
let r2 : list of (money * decimal) =
|
||||
map2
|
||||
(λ (x1: decimal) (x2: (money * money)) →
|
||||
let x3 : (decimal * money * money) = (x1, x2.0, x2.1) in
|
||||
grok x3.0 x3.1 x3.2)
|
||||
lis1
|
||||
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
|
||||
in
|
||||
let r3 : list of (money * decimal) =
|
||||
map2
|
||||
(λ (x1: decimal) (x2: (money * money)) →
|
||||
let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in
|
||||
grok x_y_z.0 x_y_z.1 x_y_z.2)
|
||||
lis1
|
||||
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
|
||||
in
|
||||
let r4 : list of (money * decimal) =
|
||||
map (λ (x_y_z: (decimal * money * money)) →
|
||||
(x_y_z.1 * x_y_z.0, x_y_z.1 / x_y_z.2))
|
||||
tlist
|
||||
in
|
||||
let r5 : list of (money * decimal) =
|
||||
map2
|
||||
(λ (x1: decimal) (x2: (money * money)) →
|
||||
let x_y_z : (decimal * money * money) = (x1, x2.0, x2.1) in
|
||||
(x_y_z.1 * x_y_z.0, x_y_z.1 / x_y_z.2))
|
||||
lis1
|
||||
map2 (λ (x1: money) (x2: money) → (x1, x2)) lis2 lis3
|
||||
in
|
||||
let r6 : list of (money * decimal) =
|
||||
map2
|
||||
(λ (x1: (decimal * money)) (x2: money) →
|
||||
let xy_z : ((decimal * money) * money) = (x1, x2) in
|
||||
let xy : (decimal * money) = xy_z.0 in
|
||||
let z : money = xy_z.1 in
|
||||
(xy.1 * xy.0, xy.1 / z))
|
||||
map2
|
||||
(λ (x1: decimal) (x2: money) →
|
||||
let x_y : (decimal * money) = (x1, x2) in
|
||||
(x_y.0, x_y.1))
|
||||
lis1
|
||||
lis2
|
||||
lis3
|
||||
in
|
||||
{ S r1 = r1; r2 = r2; r3 = r3; r4 = r4; r5 = r5; r6 = r6; }
|
||||
in
|
||||
S
|
||||
```
|
||||
|
Loading…
Reference in New Issue
Block a user