Add (internally) a map2 operator

This commit is contained in:
Louis Gesbert 2024-01-24 15:36:51 +01:00
parent 7cbf7d6d1b
commit bc90a7b890
11 changed files with 31 additions and 4 deletions

View File

@ -84,6 +84,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
Format.pp_print_string fmt "=="
| Map -> Format.pp_print_string fmt "list_map"
| Map2 -> Format.pp_print_string fmt "list_map2"
| Reduce -> Format.pp_print_string fmt "list_reduce"
| Filter -> Format.pp_print_string fmt "list_filter"
| Fold -> Format.pp_print_string fmt "list_fold_left"

View File

@ -99,6 +99,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
Format.pp_print_string fmt "=="
| Map -> Format.pp_print_string fmt "catala_list_map"
| Map2 -> Format.pp_print_string fmt "catala_list_map2"
| Reduce -> Format.pp_print_string fmt "catala_list_reduce"
| Filter -> Format.pp_print_string fmt "catala_list_filter"
| Fold -> Format.pp_print_string fmt "catala_list_fold_left"

View File

@ -314,6 +314,7 @@ module Op = struct
(* * polymorphic *)
| Eq : < polymorphic ; .. > t
| Map : < polymorphic ; .. > t
| Map2 : < polymorphic ; .. > t
| Concat : < polymorphic ; .. > t
| Filter : < polymorphic ; .. > t
(* * overloaded *)

View File

@ -198,6 +198,14 @@ let rec evaluate_operator
(Mark.copy e'
(EApp { f; args = [e']; tys = [Expr.maybe_ty (Mark.get e')] })))
es)
| Map2, [f; (EArray es1, _); (EArray es2, _)] ->
EArray
(List.map2
(fun e1 e2 ->
evaluate_expr
(Mark.add m
(EApp { f; args = [e1; e2]; tys = [Expr.maybe_ty (Mark.get e1); Expr.maybe_ty (Mark.get e2)] })))
es1 es2)
| Reduce, [_; default; (EArray [], _)] -> Mark.remove default
| Reduce, [f; _; (EArray (x0 :: xn), _)] ->
Mark.remove
@ -249,7 +257,7 @@ let rec evaluate_operator
];
})))
init es)
| (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ -> err ()
| (Length | Log _ | Eq | Map | Map2 | Concat | Filter | Fold | Reduce), _ -> err ()
| Not, [(ELit (LBool b), _)] -> ELit (LBool (o_not b))
| GetDay, [(ELit (LDate d), _)] -> ELit (LInt (o_getDay d))
| GetMonth, [(ELit (LDate d), _)] -> ELit (LInt (o_getMonth d))

View File

@ -45,6 +45,7 @@ let name : type a. a t -> string = function
| Xor -> "o_xor"
| Eq -> "o_eq"
| Map -> "o_map"
| Map2 -> "o_map2"
| Concat -> "o_concat"
| Filter -> "o_filter"
| Reduce -> "o_reduce"
@ -174,6 +175,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) =
| Xor, Xor
| Eq, Eq
| Map, Map
| Map2, Map2
| Concat, Concat
| Filter, Filter
| Reduce, Reduce
@ -259,6 +261,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) =
| Xor, _ -> -1 | _, Xor -> 1
| Eq, _ -> -1 | _, Eq -> 1
| Map, _ -> -1 | _, Map -> 1
| Map2, _ -> -1 | _, Map2 -> 1
| Concat, _ -> -1 | _, Concat -> 1
| Filter, _ -> -1 | _, Filter -> 1
| Reduce, _ -> -1 | _, Reduce -> 1
@ -339,7 +342,7 @@ let kind_dispatch :
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
| Or | Xor ) as op ->
monomorphic op
| ( Log _ | Length | Eq | Map | Concat | Filter | Reduce | Fold
| ( Log _ | Length | Eq | Map | Map2 | Concat | Filter | Reduce | Fold
| HandleDefault | HandleDefaultOpt | FromClosureEnv | ToClosureEnv ) as op
->
polymorphic op
@ -371,7 +374,7 @@ type 'a no_overloads =
let translate (t : 'a no_overloads t) : 'b no_overloads t =
match t with
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
| Or | Xor | HandleDefault | HandleDefaultOpt | Log _ | Length | Eq | Map
| Or | Xor | HandleDefault | HandleDefaultOpt | Log _ | Length | Eq | Map | Map2
| Concat | Filter | Reduce | Fold | Minus_int | Minus_rat | Minus_mon
| Minus_dur | ToRat_int | ToRat_mon | ToMoney_rat | Round_rat | Round_mon
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur _ | Add_dur_dur

View File

@ -222,6 +222,7 @@ let operator_to_string : type a. a Op.t -> string =
| Xor -> "xor"
| Eq -> "="
| Map -> "map"
| Map2 -> "map2"
| Reduce -> "reduce"
| Concat -> "++"
| Filter -> "filter"
@ -306,6 +307,7 @@ let operator_to_shorter_string : type a. a Op.t -> string =
| Xor -> "xor"
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dur_dur | Eq_dat_dat | Eq -> "="
| Map -> "map"
| Map2 -> "map2"
| Reduce -> "reduce"
| Concat -> "++"
| Filter -> "filter"
@ -407,7 +409,7 @@ module Precedence = struct
| Div | Div_int_int | Div_rat_rat | Div_mon_rat | Div_mon_mon
| Div_dur_dur ->
Op Div
| HandleDefault | HandleDefaultOpt | Map | Concat | Filter | Reduce | Fold
| HandleDefault | HandleDefaultOpt | Map | Map2 | Concat | Filter | Reduce | Fold
| ToClosureEnv | FromClosureEnv ->
App)
| EApp _ -> App

View File

@ -287,6 +287,7 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
let pos = Mark.get op in
let any = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
let any2 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
let any3 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
let bt = lazy (UnionFind.make (TLit TBool, pos)) in
let ut = lazy (UnionFind.make (TLit TUnit, pos)) in
let it = lazy (UnionFind.make (TLit TInt, pos)) in
@ -302,6 +303,7 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
| Fold -> [[any2; any] @-> any2; any2; array any] @-> any2
| Eq -> [any; any] @-> bt
| Map -> [[any] @-> any2; array any] @-> array any2
| Map2 -> [[any; any2] @-> any3; array any; array any2] @-> array any3
| Filter -> [[any] @-> bt; array any] @-> array any
| Reduce -> [[any; any] @-> any; any; array any] @-> any
| Concat -> [array any; array any] @-> array any

View File

@ -50,6 +50,7 @@ exception UncomparableDurations
exception IndivisibleDurations
exception ImpossibleDate
exception NoValueProvided of source_position
exception NotSameLength
(* TODO: register exception printers for the above
(Printexc.register_printer) *)
@ -660,6 +661,7 @@ module Oper = struct
let o_xor : bool -> bool -> bool = ( <> )
let o_eq = ( = )
let o_map = Array.map
let o_map2 f a b = try Array.map2 f a b with Invalid_argument _ -> raise NotSameLength
let o_reduce f dft a =
let len = Array.length a in

View File

@ -353,6 +353,7 @@ module Oper : sig
val o_xor : bool -> bool -> bool
val o_eq : 'a -> 'a -> bool
val o_map : ('a -> 'b) -> 'a array -> 'b array
val o_map2 : ('a -> 'b -> 'c) -> 'a array -> 'b array -> 'c array
val o_reduce : ('a -> 'a -> 'a) -> 'a -> 'a array -> 'a
val o_concat : 'a array -> 'a array -> 'a array
val o_filter : ('a -> bool) -> 'a array -> 'a array

View File

@ -571,6 +571,8 @@ def list_filter(f: Callable[[Alpha], bool], l: List[Alpha]) -> List[Alpha]:
def list_map(f: Callable[[Alpha], Beta], l: List[Alpha]) -> List[Beta]:
return [f(i) for i in l]
def list_map2(f: Callable[[Alpha, Beta], Gamma], l1: List[Alpha], l2: List[Beta]) -> List[Gamma]:
return [f(i, j) for i, j in zip(l1, l2, strict=True)]
def list_reduce(f: Callable[[Alpha, Alpha], Alpha], dft: Alpha, l: List[Alpha]) -> Alpha:
if l == []:

View File

@ -285,6 +285,10 @@ catala_list_map <- function(f, l) {
Map(f, l)
}
#' @export
catala_list_map2 <- function(f, l1, l2) {
Map(f, l1, l2)
}
#' @export
catala_list_reduce <- function(f, default, l) {
if (length(l) == 0) {
default