diff --git a/src/catala/dcalc/ast.ml b/src/catala/dcalc/ast.ml index c1c759b3..fe07c44f 100644 --- a/src/catala/dcalc/ast.ml +++ b/src/catala/dcalc/ast.ml @@ -61,6 +61,7 @@ type ternop = Fold type binop = | And | Or + | Xor | Add of op_kind | Sub of op_kind | Mult of op_kind diff --git a/src/catala/dcalc/ast.mli b/src/catala/dcalc/ast.mli index 298c0125..e7b32f51 100644 --- a/src/catala/dcalc/ast.mli +++ b/src/catala/dcalc/ast.mli @@ -68,6 +68,7 @@ type ternop = Fold type binop = | And | Or + | Xor | Add of op_kind | Sub of op_kind | Mult of op_kind diff --git a/src/catala/dcalc/print.ml b/src/catala/dcalc/print.ml index f9e77db8..3bc53bc7 100644 --- a/src/catala/dcalc/print.ml +++ b/src/catala/dcalc/print.ml @@ -98,7 +98,7 @@ let format_lit (fmt : Format.formatter) (l : lit Pos.marked) : unit = | LMoney e -> ( match !Utils.Cli.locale_lang with | `En -> Format.fprintf fmt "$%s" (Runtime.money_to_string e) - | `Fr -> Format.fprintf fmt "%s €" (Runtime.money_to_string e) ) + | `Fr -> Format.fprintf fmt "%s €" (Runtime.money_to_string e)) | LDate d -> Format.fprintf fmt "%s" (Runtime.date_to_string d) | LDuration d -> Format.fprintf fmt "%s" (Runtime.duration_to_string d) @@ -114,6 +114,7 @@ let format_binop (fmt : Format.formatter) (op : binop Pos.marked) : unit = | Div k -> format_operator fmt (Format.asprintf "/%a" format_op_kind k) | And -> format_operator fmt (Format.asprintf "%s" "&&") | Or -> format_operator fmt (Format.asprintf "%s" "||") + | Xor -> format_operator fmt (Format.asprintf "%s" "xor") | Eq -> format_operator fmt (Format.asprintf "%s" "=") | Neq -> format_operator fmt (Format.asprintf "%s" "!=") | Lt k -> format_operator fmt (Format.asprintf "%s%a" "<" format_op_kind k) @@ -136,7 +137,7 @@ let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit = let format_unop (fmt : Format.formatter) (op : unop Pos.marked) : unit = Format.fprintf fmt "%s" - ( match Pos.unmark op with + (match Pos.unmark op with | Minus _ -> "-" | Not -> "~" | ErrorOnEmpty -> "error_empty" @@ -150,7 +151,7 @@ let format_unop (fmt : Format.formatter) (op : unop Pos.marked) : unit = | IntToRat -> "int_to_rat" | GetDay -> "get_day" | GetMonth -> "get_month" - | GetYear -> "get_year" ) + | GetYear -> "get_year") let needs_parens (e : expr Pos.marked) : bool = match Pos.unmark e with EAbs _ | ETuple (_, Some _) -> true | _ -> false @@ -196,7 +197,7 @@ let rec format_expr (ctx : Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos. Format.fprintf fmt "%a%a%a%a%a" format_expr e1 format_punctuation "." format_punctuation "\"" Ast.StructFieldName.format_t (fst (List.nth (Ast.StructMap.find s ctx.ctx_structs) n)) - format_punctuation "\"" ) + format_punctuation "\"") | EInj (e, n, en, _ts) -> Format.fprintf fmt "@[%a@ %a@]" Ast.EnumConstructor.format_t (fst (List.nth (Ast.EnumMap.find en ctx.ctx_enums) n)) diff --git a/src/catala/dcalc/typing.ml b/src/catala/dcalc/typing.ml index d795be68..f9000d85 100644 --- a/src/catala/dcalc/typing.ml +++ b/src/catala/dcalc/typing.ml @@ -158,7 +158,7 @@ let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem = let arr x y = UnionFind.make (TArrow (x, y), pos) in match Pos.unmark op with | A.Ternop A.Fold -> arr (arr any2 (arr any any2)) (arr any2 (arr array_any any2)) - | A.Binop (A.And | A.Or) -> arr bt (arr bt bt) + | A.Binop (A.And | A.Or | A.Xor) -> arr bt (arr bt bt) | A.Binop (A.Add KInt | A.Sub KInt | A.Mult KInt | A.Div KInt) -> arr it (arr it it) | A.Binop (A.Add KRat | A.Sub KRat | A.Mult KRat | A.Div KRat) -> arr rt (arr rt rt) | A.Binop (A.Add KMoney | A.Sub KMoney) -> arr mt (arr mt mt) @@ -234,7 +234,7 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po | Some t -> t | None -> Errors.raise_spanned_error "Variable not found in the current context" - (Pos.get_position e) ) + (Pos.get_position e)) | ELit (LBool _) -> UnionFind.make (Pos.same_pos_as (TLit TBool) e) | ELit (LInt _) -> UnionFind.make (Pos.same_pos_as (TLit TInt) e) | ELit (LRat _) -> UnionFind.make (Pos.same_pos_as (TLit TRat) e) @@ -258,7 +258,7 @@ let rec typecheck_expr_bottom_up (ctx : Ast.decl_ctx) (env : env) (e : A.expr Po (Format.asprintf "Expression should have a tuple type with at least %d elements but only has %d" n (List.length typs)) - (Pos.get_position e1) ) + (Pos.get_position e1)) | EInj (e1, n, e_name, ts) -> let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in let ts_n = @@ -362,7 +362,7 @@ and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.mar | Some tau' -> ignore (unify ctx tau tau') | None -> Errors.raise_spanned_error "Variable not found in the current context" - (Pos.get_position e) ) + (Pos.get_position e)) | ELit (LBool _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TBool) e)) | ELit (LInt _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TInt) e)) | ELit (LRat _) -> unify ctx tau (UnionFind.make (Pos.same_pos_as (TLit TRat) e)) @@ -387,7 +387,7 @@ and typecheck_expr_top_down (ctx : Ast.decl_ctx) (env : env) (e : A.expr Pos.mar (Format.asprintf "Expression should have a tuple type with at least %d elements but only has %d" n (List.length typs)) - (Pos.get_position e1) ) + (Pos.get_position e1)) | EInj (e1, n, e_name, ts) -> let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in let ts_n = diff --git a/src/catala/lcalc/to_ocaml.ml b/src/catala/lcalc/to_ocaml.ml index 4b6e6975..1fb7f95f 100644 --- a/src/catala/lcalc/to_ocaml.ml +++ b/src/catala/lcalc/to_ocaml.ml @@ -56,7 +56,7 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) : un | And -> Format.fprintf fmt "%s" "&&" | Or -> Format.fprintf fmt "%s" "||" | Eq -> Format.fprintf fmt "%s" "=" - | Neq -> Format.fprintf fmt "%s" "<>" + | Neq | Xor -> Format.fprintf fmt "%s" "<>" | Lt k -> Format.fprintf fmt "%s%a" "<" format_op_kind k | Lte k -> Format.fprintf fmt "%s%a" "<=" format_op_kind k | Gt k -> Format.fprintf fmt "%s%a" ">" format_op_kind k @@ -245,7 +245,7 @@ let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : exp format_with_parens e1 | Some s -> Format.fprintf fmt "%a.%a" format_with_parens e1 format_struct_field_name - (fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n)) ) + (fst (List.nth (Dcalc.Ast.StructMap.find s ctx.ctx_structs) n))) | EInj (e, n, en, _ts) -> Format.fprintf fmt "@[%a@ %a@]" format_enum_cons_name (fst (List.nth (Dcalc.Ast.EnumMap.find en ctx.ctx_enums) n)) diff --git a/src/catala/surface/lexer.ml b/src/catala/surface/lexer.ml index eb320ff4..43435345 100644 --- a/src/catala/surface/lexer.ml +++ b/src/catala/surface/lexer.ml @@ -288,6 +288,9 @@ let rec lex_code (lexbuf : lexbuf) : token = | "||" -> update_acc lexbuf; OR + | "xor" -> + update_acc lexbuf; + XOR | "not" -> update_acc lexbuf; NOT diff --git a/src/catala/surface/parser.mly b/src/catala/surface/parser.mly index 6a4a7e73..4364ba0d 100644 --- a/src/catala/surface/parser.mly +++ b/src/catala/surface/parser.mly @@ -42,7 +42,7 @@ %token LESSER_DATE GREATER_DATE LESSER_EQUAL_DATE GREATER_EQUAL_DATE %token LESSER_DURATION GREATER_DURATION LESSER_EQUAL_DURATION GREATER_EQUAL_DURATION %token EXISTS IN SUCH THAT -%token DOT AND OR LPAREN RPAREN EQUAL +%token DOT AND OR XOR LPAREN RPAREN EQUAL %token CARDINAL ASSERTION FIXED BY YEAR MONTH DAY %token PLUS MINUS MULT DIV %token PLUSDEC MINUSDEC MULTDEC DIVDEC @@ -294,6 +294,7 @@ struct_or_enum_inject: logical_op: | AND { (And, Pos.from_lpos $sloc) } | OR { (Or, Pos.from_lpos $sloc) } + | XOR { (Neq, Pos.from_lpos $sloc) } logical_unop: | NOT { (Not, Pos.from_lpos $sloc) } diff --git a/tests/test_bool/bad/output/test_xor_with_int.catala.TestXorWithInt.out b/tests/test_bool/bad/output/test_xor_with_int.catala.TestXorWithInt.out new file mode 100644 index 00000000..dcdc4962 --- /dev/null +++ b/tests/test_bool/bad/output/test_xor_with_int.catala.TestXorWithInt.out @@ -0,0 +1,20 @@ +[ERROR] Error during typechecking, incompatible types: +[ERROR] --> integer +[ERROR] --> bool +[ERROR] +[ERROR] Error coming from typechecking the following expression: +[ERROR] No position information +[ERROR] +[ERROR] Type integer coming from expression: +[ERROR] --> test_bool/bad/test_xor_with_int.catala +[ERROR] | +[ERROR] 4 | new scope TestXorWithInt : +[ERROR] | ^^^^^^^^^^^^^^ +[ERROR] + 'xor' should be a boolean operator +[ERROR] +[ERROR] Type bool coming from expression: +[ERROR] --> test_bool/bad/test_xor_with_int.catala +[ERROR] | +[ERROR] 8 | def test_var := 10 xor 20 +[ERROR] | ^^^ +[ERROR] + 'xor' should be a boolean operator diff --git a/tests/test_bool/bad/test_xor_with_int.catala b/tests/test_bool/bad/test_xor_with_int.catala new file mode 100644 index 00000000..5720f3c3 --- /dev/null +++ b/tests/test_bool/bad/test_xor_with_int.catala @@ -0,0 +1,9 @@ +## ['xor' should be a boolean operator] + +```catala +new scope TestXorWithInt : + param test_var content int + +scope TestXorWithInt : + def test_var := 10 xor 20 +``` diff --git a/tests/test_bool/good/output/test_xor.catala.TestXor.out b/tests/test_bool/good/output/test_xor.catala.TestXor.out new file mode 100644 index 00000000..90530526 --- /dev/null +++ b/tests/test_bool/good/output/test_xor.catala.TestXor.out @@ -0,0 +1,5 @@ +[RESULT] Computation successful! Results: +[RESULT] f_xor_f = false +[RESULT] f_xor_t = true +[RESULT] t_xor_f = true +[RESULT] t_xor_t = false diff --git a/tests/test_bool/good/test_xor.catala b/tests/test_bool/good/test_xor.catala new file mode 100644 index 00000000..270038c3 --- /dev/null +++ b/tests/test_bool/good/test_xor.catala @@ -0,0 +1,15 @@ +## [Test all 'xor' combinations] + +```catala +new scope TestXor : + param t_xor_t content bool + param t_xor_f content bool + param f_xor_t content bool + param f_xor_f content bool + +scope TestXor : + def t_xor_t := true xor true + def f_xor_t := false xor true + def t_xor_f := true xor false + def f_xor_f := false xor false +```