mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Add (internally) a map2 operator
This commit is contained in:
parent
7cbf7d6d1b
commit
bc90a7b890
@ -84,6 +84,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
|
||||
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
|
||||
Format.pp_print_string fmt "=="
|
||||
| Map -> Format.pp_print_string fmt "list_map"
|
||||
| Map2 -> Format.pp_print_string fmt "list_map2"
|
||||
| Reduce -> Format.pp_print_string fmt "list_reduce"
|
||||
| Filter -> Format.pp_print_string fmt "list_filter"
|
||||
| Fold -> Format.pp_print_string fmt "list_fold_left"
|
||||
|
@ -99,6 +99,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
|
||||
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
|
||||
Format.pp_print_string fmt "=="
|
||||
| Map -> Format.pp_print_string fmt "catala_list_map"
|
||||
| Map2 -> Format.pp_print_string fmt "catala_list_map2"
|
||||
| Reduce -> Format.pp_print_string fmt "catala_list_reduce"
|
||||
| Filter -> Format.pp_print_string fmt "catala_list_filter"
|
||||
| Fold -> Format.pp_print_string fmt "catala_list_fold_left"
|
||||
|
@ -314,6 +314,7 @@ module Op = struct
|
||||
(* * polymorphic *)
|
||||
| Eq : < polymorphic ; .. > t
|
||||
| Map : < polymorphic ; .. > t
|
||||
| Map2 : < polymorphic ; .. > t
|
||||
| Concat : < polymorphic ; .. > t
|
||||
| Filter : < polymorphic ; .. > t
|
||||
(* * overloaded *)
|
||||
|
@ -198,6 +198,14 @@ let rec evaluate_operator
|
||||
(Mark.copy e'
|
||||
(EApp { f; args = [e']; tys = [Expr.maybe_ty (Mark.get e')] })))
|
||||
es)
|
||||
| Map2, [f; (EArray es1, _); (EArray es2, _)] ->
|
||||
EArray
|
||||
(List.map2
|
||||
(fun e1 e2 ->
|
||||
evaluate_expr
|
||||
(Mark.add m
|
||||
(EApp { f; args = [e1; e2]; tys = [Expr.maybe_ty (Mark.get e1); Expr.maybe_ty (Mark.get e2)] })))
|
||||
es1 es2)
|
||||
| Reduce, [_; default; (EArray [], _)] -> Mark.remove default
|
||||
| Reduce, [f; _; (EArray (x0 :: xn), _)] ->
|
||||
Mark.remove
|
||||
@ -249,7 +257,7 @@ let rec evaluate_operator
|
||||
];
|
||||
})))
|
||||
init es)
|
||||
| (Length | Log _ | Eq | Map | Concat | Filter | Fold | Reduce), _ -> err ()
|
||||
| (Length | Log _ | Eq | Map | Map2 | Concat | Filter | Fold | Reduce), _ -> err ()
|
||||
| Not, [(ELit (LBool b), _)] -> ELit (LBool (o_not b))
|
||||
| GetDay, [(ELit (LDate d), _)] -> ELit (LInt (o_getDay d))
|
||||
| GetMonth, [(ELit (LDate d), _)] -> ELit (LInt (o_getMonth d))
|
||||
|
@ -45,6 +45,7 @@ let name : type a. a t -> string = function
|
||||
| Xor -> "o_xor"
|
||||
| Eq -> "o_eq"
|
||||
| Map -> "o_map"
|
||||
| Map2 -> "o_map2"
|
||||
| Concat -> "o_concat"
|
||||
| Filter -> "o_filter"
|
||||
| Reduce -> "o_reduce"
|
||||
@ -174,6 +175,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) =
|
||||
| Xor, Xor
|
||||
| Eq, Eq
|
||||
| Map, Map
|
||||
| Map2, Map2
|
||||
| Concat, Concat
|
||||
| Filter, Filter
|
||||
| Reduce, Reduce
|
||||
@ -259,6 +261,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) =
|
||||
| Xor, _ -> -1 | _, Xor -> 1
|
||||
| Eq, _ -> -1 | _, Eq -> 1
|
||||
| Map, _ -> -1 | _, Map -> 1
|
||||
| Map2, _ -> -1 | _, Map2 -> 1
|
||||
| Concat, _ -> -1 | _, Concat -> 1
|
||||
| Filter, _ -> -1 | _, Filter -> 1
|
||||
| Reduce, _ -> -1 | _, Reduce -> 1
|
||||
@ -339,7 +342,7 @@ let kind_dispatch :
|
||||
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
|
||||
| Or | Xor ) as op ->
|
||||
monomorphic op
|
||||
| ( Log _ | Length | Eq | Map | Concat | Filter | Reduce | Fold
|
||||
| ( Log _ | Length | Eq | Map | Map2 | Concat | Filter | Reduce | Fold
|
||||
| HandleDefault | HandleDefaultOpt | FromClosureEnv | ToClosureEnv ) as op
|
||||
->
|
||||
polymorphic op
|
||||
@ -371,7 +374,7 @@ type 'a no_overloads =
|
||||
let translate (t : 'a no_overloads t) : 'b no_overloads t =
|
||||
match t with
|
||||
| ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | And
|
||||
| Or | Xor | HandleDefault | HandleDefaultOpt | Log _ | Length | Eq | Map
|
||||
| Or | Xor | HandleDefault | HandleDefaultOpt | Log _ | Length | Eq | Map | Map2
|
||||
| Concat | Filter | Reduce | Fold | Minus_int | Minus_rat | Minus_mon
|
||||
| Minus_dur | ToRat_int | ToRat_mon | ToMoney_rat | Round_rat | Round_mon
|
||||
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur _ | Add_dur_dur
|
||||
|
@ -222,6 +222,7 @@ let operator_to_string : type a. a Op.t -> string =
|
||||
| Xor -> "xor"
|
||||
| Eq -> "="
|
||||
| Map -> "map"
|
||||
| Map2 -> "map2"
|
||||
| Reduce -> "reduce"
|
||||
| Concat -> "++"
|
||||
| Filter -> "filter"
|
||||
@ -306,6 +307,7 @@ let operator_to_shorter_string : type a. a Op.t -> string =
|
||||
| Xor -> "xor"
|
||||
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dur_dur | Eq_dat_dat | Eq -> "="
|
||||
| Map -> "map"
|
||||
| Map2 -> "map2"
|
||||
| Reduce -> "reduce"
|
||||
| Concat -> "++"
|
||||
| Filter -> "filter"
|
||||
@ -407,7 +409,7 @@ module Precedence = struct
|
||||
| Div | Div_int_int | Div_rat_rat | Div_mon_rat | Div_mon_mon
|
||||
| Div_dur_dur ->
|
||||
Op Div
|
||||
| HandleDefault | HandleDefaultOpt | Map | Concat | Filter | Reduce | Fold
|
||||
| HandleDefault | HandleDefaultOpt | Map | Map2 | Concat | Filter | Reduce | Fold
|
||||
| ToClosureEnv | FromClosureEnv ->
|
||||
App)
|
||||
| EApp _ -> App
|
||||
|
@ -287,6 +287,7 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
|
||||
let pos = Mark.get op in
|
||||
let any = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
|
||||
let any2 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
|
||||
let any3 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
|
||||
let bt = lazy (UnionFind.make (TLit TBool, pos)) in
|
||||
let ut = lazy (UnionFind.make (TLit TUnit, pos)) in
|
||||
let it = lazy (UnionFind.make (TLit TInt, pos)) in
|
||||
@ -302,6 +303,7 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
|
||||
| Fold -> [[any2; any] @-> any2; any2; array any] @-> any2
|
||||
| Eq -> [any; any] @-> bt
|
||||
| Map -> [[any] @-> any2; array any] @-> array any2
|
||||
| Map2 -> [[any; any2] @-> any3; array any; array any2] @-> array any3
|
||||
| Filter -> [[any] @-> bt; array any] @-> array any
|
||||
| Reduce -> [[any; any] @-> any; any; array any] @-> any
|
||||
| Concat -> [array any; array any] @-> array any
|
||||
|
@ -50,6 +50,7 @@ exception UncomparableDurations
|
||||
exception IndivisibleDurations
|
||||
exception ImpossibleDate
|
||||
exception NoValueProvided of source_position
|
||||
exception NotSameLength
|
||||
|
||||
(* TODO: register exception printers for the above
|
||||
(Printexc.register_printer) *)
|
||||
@ -660,6 +661,7 @@ module Oper = struct
|
||||
let o_xor : bool -> bool -> bool = ( <> )
|
||||
let o_eq = ( = )
|
||||
let o_map = Array.map
|
||||
let o_map2 f a b = try Array.map2 f a b with Invalid_argument _ -> raise NotSameLength
|
||||
|
||||
let o_reduce f dft a =
|
||||
let len = Array.length a in
|
||||
|
@ -353,6 +353,7 @@ module Oper : sig
|
||||
val o_xor : bool -> bool -> bool
|
||||
val o_eq : 'a -> 'a -> bool
|
||||
val o_map : ('a -> 'b) -> 'a array -> 'b array
|
||||
val o_map2 : ('a -> 'b -> 'c) -> 'a array -> 'b array -> 'c array
|
||||
val o_reduce : ('a -> 'a -> 'a) -> 'a -> 'a array -> 'a
|
||||
val o_concat : 'a array -> 'a array -> 'a array
|
||||
val o_filter : ('a -> bool) -> 'a array -> 'a array
|
||||
|
@ -571,6 +571,8 @@ def list_filter(f: Callable[[Alpha], bool], l: List[Alpha]) -> List[Alpha]:
|
||||
def list_map(f: Callable[[Alpha], Beta], l: List[Alpha]) -> List[Beta]:
|
||||
return [f(i) for i in l]
|
||||
|
||||
def list_map2(f: Callable[[Alpha, Beta], Gamma], l1: List[Alpha], l2: List[Beta]) -> List[Gamma]:
|
||||
return [f(i, j) for i, j in zip(l1, l2, strict=True)]
|
||||
|
||||
def list_reduce(f: Callable[[Alpha, Alpha], Alpha], dft: Alpha, l: List[Alpha]) -> Alpha:
|
||||
if l == []:
|
||||
|
@ -285,6 +285,10 @@ catala_list_map <- function(f, l) {
|
||||
Map(f, l)
|
||||
}
|
||||
#' @export
|
||||
catala_list_map2 <- function(f, l1, l2) {
|
||||
Map(f, l1, l2)
|
||||
}
|
||||
#' @export
|
||||
catala_list_reduce <- function(f, default, l) {
|
||||
if (length(l) == 0) {
|
||||
default
|
||||
|
Loading…
Reference in New Issue
Block a user