From c4d6220240fcb88603926d7deaf9647a44afa157 Mon Sep 17 00:00:00 2001 From: Nicolas Chataing Date: Wed, 6 Jan 2021 12:41:24 +0100 Subject: [PATCH] Handle extrema operators on collections --- src/catala/catala_surface/ast.ml | 16 +++---- src/catala/catala_surface/desugaring.ml | 54 +++++++++++++++++------ src/catala/catala_surface/lexer.ml | 6 ++- src/catala/catala_surface/lexer_en.ml | 6 ++- src/catala/catala_surface/lexer_fr.ml | 4 ++ src/catala/catala_surface/parser.mly | 10 +++-- tests/test_array/aggregation.catala | 6 ++- tests/test_array/aggregation.catala.B.out | 2 + 8 files changed, 76 insertions(+), 28 deletions(-) diff --git a/src/catala/catala_surface/ast.ml b/src/catala/catala_surface/ast.ml index 3a7f6361..c52ac8fe 100644 --- a/src/catala/catala_surface/ast.ml +++ b/src/catala/catala_surface/ast.ml @@ -89,11 +89,6 @@ type unop = Not | Minus of op_kind type builtin_expression = Cardinal | IntToDec | GetDay | GetMonth | GetYear -type aggregate_func = - | AggregateSum of primitive_typ - | AggregateCount - | AggregateExtremum of bool (* true if max *) * primitive_typ - type literal_date = { literal_date_day : int Pos.marked; literal_date_month : int Pos.marked; @@ -104,8 +99,6 @@ type literal_number = Int of Z.t | Dec of Z.t * Z.t type literal_unit = Percent | Year | Month | Day -type collection_op = Exists | Forall | Aggregate of aggregate_func - type money_amount = { money_amount_units : Z.t; money_amount_cents : Z.t } type literal = @@ -114,7 +107,14 @@ type literal = | MoneyAmount of money_amount | Date of literal_date -type match_case = { +type aggregate_func = + | AggregateSum of primitive_typ + | AggregateCount + | AggregateExtremum of bool (* true if max *) * primitive_typ * expression Pos.marked + +and collection_op = Exists | Forall | Aggregate of aggregate_func + +and match_case = { match_case_pattern : match_case_pattern Pos.marked; match_case_expr : expression Pos.marked; } diff --git a/src/catala/catala_surface/desugaring.ml b/src/catala/catala_surface/desugaring.ml index 30144abf..87c839be 100644 --- a/src/catala/catala_surface/desugaring.ml +++ b/src/catala/catala_surface/desugaring.ml @@ -361,9 +361,7 @@ let rec translate_expr (scope : Scopelang.Ast.ScopeName.t) (ctxt : Name_resoluti Bindlib.box (Scopelang.Ast.ELit (Dcalc.Ast.LMoney Z.zero), Pos.get_position op') | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> Bindlib.box (Scopelang.Ast.ELit (Dcalc.Ast.LDuration Z.zero), Pos.get_position op') - | Ast.Aggregate (Ast.AggregateExtremum _) -> - Errors.raise_spanned_error "Unsupported feature: minimum and maximum" - (Pos.get_position op') + | Ast.Aggregate (Ast.AggregateExtremum (_, _, init)) -> rec_helper init | Ast.Aggregate (Ast.AggregateSum t) -> Errors.raise_spanned_error (Format.asprintf "It is impossible to sum two values of type %a together" @@ -385,6 +383,20 @@ let rec translate_expr (scope : Scopelang.Ast.ScopeName.t) (ctxt : Name_resoluti (translate_expr scope ctxt predicate) acc in + let make_extr_body (cmp_op : Dcalc.Ast.binop) = + Bindlib.box_apply2 + (fun predicate acc -> + ( Scopelang.Ast.EIfThenElse + ( ( Scopelang.Ast.EApp + ( (Scopelang.Ast.EOp (Dcalc.Ast.Binop cmp_op), Pos.get_position op'), + [ acc; predicate ] ), + pos ), + acc, + predicate ), + pos )) + (translate_expr scope ctxt predicate) + acc + in match Pos.unmark op' with | Ast.Exists -> make_body Dcalc.Ast.Or | Ast.Forall -> make_body Dcalc.Ast.And @@ -394,9 +406,17 @@ let rec translate_expr (scope : Scopelang.Ast.ScopeName.t) (ctxt : Name_resoluti | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> make_body (Dcalc.Ast.Add Dcalc.Ast.KDuration) | Ast.Aggregate (Ast.AggregateSum _) -> assert false (* should not happen *) - | Ast.Aggregate (Ast.AggregateExtremum _) -> - Errors.raise_spanned_error "Unsupported feature: minimum and maximum" - (Pos.get_position op') + | Ast.Aggregate (Ast.AggregateExtremum (max_or_min, t, _)) -> + let op_kind = + match t with + | Ast.Integer -> Dcalc.Ast.KInt + | Ast.Decimal -> Dcalc.Ast.KRat + | Ast.Money -> Dcalc.Ast.KMoney + | Ast.Duration -> Dcalc.Ast.KDuration + | _ -> assert false + in + let cmp_op = if max_or_min then Dcalc.Ast.Gt op_kind else Dcalc.Ast.Lt op_kind in + make_extr_body cmp_op | Ast.Aggregate Ast.AggregateCount -> Bindlib.box_apply2 (fun predicate acc -> @@ -434,14 +454,20 @@ let rec translate_expr (scope : Scopelang.Ast.ScopeName.t) (ctxt : Name_resoluti match Pos.unmark op' with | Ast.Exists -> make_f Dcalc.Ast.TBool | Ast.Forall -> make_f Dcalc.Ast.TBool - | Ast.Aggregate (Ast.AggregateSum Ast.Integer) -> make_f Dcalc.Ast.TInt - | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) -> make_f Dcalc.Ast.TRat - | Ast.Aggregate (Ast.AggregateSum Ast.Money) -> make_f Dcalc.Ast.TMoney - | Ast.Aggregate (Ast.AggregateSum Ast.Duration) -> make_f Dcalc.Ast.TDuration - | Ast.Aggregate (Ast.AggregateExtremum _) -> - Errors.raise_spanned_error "Unsupported feature: minimum and maximum" - (Pos.get_position op') - | Ast.Aggregate (Ast.AggregateSum _) -> assert false (* should not happen *) + | Ast.Aggregate (Ast.AggregateSum Ast.Integer) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Integer, _)) -> + make_f Dcalc.Ast.TInt + | Ast.Aggregate (Ast.AggregateSum Ast.Decimal) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Decimal, _)) -> + make_f Dcalc.Ast.TRat + | Ast.Aggregate (Ast.AggregateSum Ast.Money) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Money, _)) -> + make_f Dcalc.Ast.TMoney + | Ast.Aggregate (Ast.AggregateSum Ast.Duration) + | Ast.Aggregate (Ast.AggregateExtremum (_, Ast.Duration, _)) -> + make_f Dcalc.Ast.TDuration + | Ast.Aggregate (Ast.AggregateSum _) | Ast.Aggregate (Ast.AggregateExtremum _) -> + assert false (* should not happen *) | Ast.Aggregate Ast.AggregateCount -> make_f Dcalc.Ast.TInt in Bindlib.box_apply3 diff --git a/src/catala/catala_surface/lexer.ml b/src/catala/catala_surface/lexer.ml index a4d72e63..a57ffa8b 100644 --- a/src/catala/catala_surface/lexer.ml +++ b/src/catala/catala_surface/lexer.ml @@ -78,7 +78,8 @@ let token_list : (string * token) list = ("decreasing", DECREASING); ("increasing", INCREASING); ("maximum", MAXIMUM); - ("minimum", MAXIMUM); + ("minimum", MINIMUM); + ("init", INIT); ("of", OF); ("set", COLLECTION); ("enum", ENUM); @@ -310,6 +311,9 @@ let rec lex_code (lexbuf : lexbuf) : token = | "minimum" -> update_acc lexbuf; MINIMUM + | "init" -> + update_acc lexbuf; + INIT | "number" -> update_acc lexbuf; CARDINAL diff --git a/src/catala/catala_surface/lexer_en.ml b/src/catala/catala_surface/lexer_en.ml index ceb027cf..138d114a 100644 --- a/src/catala/catala_surface/lexer_en.ml +++ b/src/catala/catala_surface/lexer_en.ml @@ -71,7 +71,8 @@ let token_list_en : (string * token) list = ("or", OR); ("not", NOT); ("maximum", MAXIMUM); - ("minimum", MAXIMUM); + ("minimum", MINIMUM); + ("initial", INIT); ("number", CARDINAL); ("year", YEAR); ("month", MONTH); @@ -264,6 +265,9 @@ let rec lex_code_en (lexbuf : lexbuf) : token = | "minimum" -> L.update_acc lexbuf; MINIMUM + | "initial" -> + L.update_acc lexbuf; + INIT | "number" -> L.update_acc lexbuf; CARDINAL diff --git a/src/catala/catala_surface/lexer_fr.ml b/src/catala/catala_surface/lexer_fr.ml index 3c20bba3..36135ced 100644 --- a/src/catala/catala_surface/lexer_fr.ml +++ b/src/catala/catala_surface/lexer_fr.ml @@ -71,6 +71,7 @@ let token_list_fr : (string * token) list = ("nombre", CARDINAL); ("maximum", MAXIMUM); ("minimum", MINIMUM); + ("initial", INIT); ("an", YEAR); ("mois", MONTH); ("jour", DAY); @@ -259,6 +260,9 @@ let rec lex_code_fr (lexbuf : lexbuf) : token = | "minimum" -> L.update_acc lexbuf; MINIMUM + | "initial" -> + L.update_acc lexbuf; + INIT | "entier_vers_d", 0xE9, "cimal" -> L.update_acc lexbuf; INT_TO_DEC diff --git a/src/catala/catala_surface/parser.mly b/src/catala/catala_surface/parser.mly index 73598c70..1bc3189e 100644 --- a/src/catala/catala_surface/parser.mly +++ b/src/catala/catala_surface/parser.mly @@ -63,7 +63,7 @@ %token BEGIN_METADATA END_METADATA MONEY DECIMAL %token UNDER_CONDITION CONSEQUENCE LBRACKET RBRACKET %token LABEL EXCEPTION LSQUARE RSQUARE SEMICOLON -%token INT_TO_DEC MAXIMUM MINIMUM +%token INT_TO_DEC MAXIMUM MINIMUM INIT %token GET_DAY GET_MONTH GET_YEAR %type source_file_or_master @@ -224,8 +224,12 @@ compare_op: | NOT_EQUAL { (Neq, $sloc) } aggregate_func: -| MAXIMUM t = typ_base { (Aggregate (AggregateExtremum (true, Pos.unmark t)), $sloc) } -| MINIMUM t = typ_base { (Aggregate (AggregateExtremum (false, Pos.unmark t)), $sloc) } +| MAXIMUM t = typ_base INIT init = primitive_expression { + (Aggregate (AggregateExtremum (true, Pos.unmark t, init)), $sloc) +} +| MINIMUM t = typ_base INIT init = primitive_expression { + (Aggregate (AggregateExtremum (false, Pos.unmark t, init)), $sloc) +} | SUM t = typ_base { (Aggregate (AggregateSum (Pos.unmark t)), $sloc) } | CARDINAL { (Aggregate AggregateCount, $sloc) } diff --git a/tests/test_array/aggregation.catala b/tests/test_array/aggregation.catala index 3ed64125..c42a1f27 100644 --- a/tests/test_array/aggregation.catala +++ b/tests/test_array/aggregation.catala @@ -9,10 +9,14 @@ scope A: new scope B: param a scope A + param max content money + param min content money param y content money param z content int scope B: + def max := maximum money init $0 for m in a.x of m *$ 2.0 + def min := minimum money init $20 for m in a.x of m +$ $5 def y := sum money for m in a.x of (m +$ $1) def z := number for m in a.x of (m >=$ $8.95) -*/ \ No newline at end of file +*/ diff --git a/tests/test_array/aggregation.catala.B.out b/tests/test_array/aggregation.catala.B.out index 5e8aef0b..483e0879 100644 --- a/tests/test_array/aggregation.catala.B.out +++ b/tests/test_array/aggregation.catala.B.out @@ -1,3 +1,5 @@ [RESULT] Computation successful! Results: +[RESULT] max = $18.00 +[RESULT] min = $5.00 [RESULT] y = $17.20 [RESULT] z = 1