Add overloaded operators for the common operations

This uses the same disambiguation mechanism put in place for
structures, calling the typer on individual rules on the desugared AST
to propagate types, in order to resolve ambiguous operators like `+`
to their strongly typed counterparts (`+!`, `+.`, `+$`, `+@`, `+$`) in
the translation to scopelang.

The patch includes some normalisation of the definition of all the
operators, and classifies them based on their typing policy instead of
their arity. It also adds a little more flexibility:
- a couple new operators, like `-` on date and duration
- optional type annotation on some aggregation constructions

The `Shared_ast` lib is also lightly restructured, with the `Expr`
module split into `Type`, `Operator` and `Expr`.
This commit is contained in:
Louis Gesbert 2022-11-29 09:47:53 +01:00
parent 5bcc0a65eb
commit fea01cfe4c
53 changed files with 14482 additions and 13013 deletions

View File

@ -48,6 +48,12 @@ let to_camel_case (s : string) : string =
last_was_underscore := is_underscore);
!out
let remove_prefix ~prefix s =
if starts_with ~prefix s then
let plen = length prefix in
sub s plen (length s - plen)
else s
let format_t = Format.pp_print_string
module Set = Set.Make (Stdlib.String)

View File

@ -39,4 +39,10 @@ val to_camel_case : string -> string
(** Converts snake_case into CamlCase after removing Remove all diacritics on
Latin letters. *)
val remove_prefix : prefix:string -> string -> string
(** [remove_prefix ~prefix str] returns
- if [str] starts with [prefix], a string [s] such that [prefix ^ s = str]
- otherwise, [str] unchanged *)
val format_t : Format.formatter -> string -> unit

View File

@ -148,7 +148,7 @@ let tag_with_log_entry
(l : log_entry)
(markings : Uid.MarkedString.info list) : 'm Ast.expr boxed =
let m = mark_tany (Marked.get_mark e) (Expr.pos e) in
Expr.eapp (Expr.eop (Unop (Log (l, markings))) m) [e] m
Expr.eapp (Expr.eop (Log (l, markings)) [TAny, Expr.pos e] m) [e] m
(* In a list of exceptions, it is normally an error if more than a single one
apply at the same time. This relaxes this constraint slightly, allowing a
@ -417,7 +417,7 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
(translate_expr ctx efalse)
m
| EOp op -> Expr.eop (Expr.translate_op op) m
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys m
| EErrorOnEmpty e' -> Expr.eerroronempty (translate_expr ctx e') m
| EArray es -> Expr.earray (List.map (translate_expr ctx) es) m

View File

@ -29,278 +29,114 @@ let log_indent = ref 0
(** {1 Evaluation} *)
let rec evaluate_operator
(ctx : decl_ctx)
(op : dcalc operator)
(pos : Pos.t)
(args : 'm Ast.expr list) : 'm Ast.naked_expr =
(* Try to apply [div] and if a [Division_by_zero] exceptions is catched, use
[op] to raise multispanned errors. *)
let apply_div_or_raise_err (div : unit -> 'm Ast.naked_expr) :
'm Ast.naked_expr =
try div ()
with Division_by_zero ->
let print_log ctx entry infos pos e =
if !Cli.trace_flag then
match entry with
| VarDef _ ->
(* TODO: this usage of Format is broken, Formatting requires that all is
formatted in one pass, without going through intermediate "%s" *)
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(match Marked.unmark e with
| EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
| _ ->
let expr_str =
Format.asprintf "%a" (Expr.format ctx ~debug:false) e
in
let expr_str =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
~subst:(fun _ -> " ")
expr_str
in
Cli.with_style [ANSITerminal.green] "%s" expr_str)
| PosRecordIfTrueBool -> (
match pos <> Pos.no_pos, Marked.unmark e with
| true, ELit (LBool true) ->
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry entry
(Cli.with_style [ANSITerminal.green] "Definition applied")
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
Format.asprintf "%*s" (!log_indent * 2) ""))
| _ -> ())
| BeginCall ->
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos;
log_indent := !log_indent + 1
| EndCall ->
log_indent := !log_indent - 1;
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(* Todo: this should be handled early when resolving overloads *)
let rec handle_eq ctx pos e1 e2 =
let open Runtime.Oper in
match e1, e2 with
| ELit LUnit, ELit LUnit -> true
| ELit (LBool b1), ELit (LBool b2) -> not (o_xor b1 b2)
| ELit (LInt x1), ELit (LInt x2) -> o_eq_int_int x1 x2
| ELit (LRat x1), ELit (LRat x2) -> o_eq_rat_rat x1 x2
| ELit (LMoney x1), ELit (LMoney x2) -> o_eq_mon_mon x1 x2
| ELit (LDuration x1), ELit (LDuration x2) -> o_eq_dur_dur x1 x2
| ELit (LDate x1), ELit (LDate x2) -> o_eq_dat_dat x1 x2
| EArray es1, EArray es2 -> (
try
List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false)
| EStruct { fields = es1; name = s1 }, EStruct { fields = es2; name = s2 } ->
StructName.equal s1 s2
&& StructField.Map.equal
(fun e1 e2 ->
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
| ( EInj { e = e1; cons = i1; name = en1 },
EInj { e = e2; cons = i2; name = en2 } ) -> (
try
EnumName.equal en1 en2
&& EnumConstructor.equal i1 i2
&&
match evaluate_operator ctx Eq pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *)
with Invalid_argument _ -> false)
| _, _ -> false (* comparing anything else return false *)
and evaluate_operator :
type k.
decl_ctx ->
(dcalc, k) operator ->
Pos.t ->
'm Ast.expr list ->
'm Ast.naked_expr =
fun ctx op pos args ->
let protect f x y =
let get_binop_args_pos = function
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
[None, Expr.pos arg0; None, Expr.pos arg1]
| _ -> assert false
in
try f x y with
| Division_by_zero ->
Errors.raise_multispanned_error
[
Some "The division operator:", pos;
Some "The null denominator:", Expr.pos (List.nth args 1);
]
"division by zero at runtime"
in
let get_binop_args_pos = function
| (arg0 :: arg1 :: _ : 'm Ast.expr list) ->
[None, Expr.pos arg0; None, Expr.pos arg1]
| _ -> assert false
in
(* Try to apply [cmp] and if a [UncomparableDurations] exceptions is catched,
use [args] to raise multispanned errors. *)
let apply_cmp_or_raise_err
(cmp : unit -> 'm Ast.naked_expr)
(args : 'm Ast.expr list) : 'm Ast.naked_expr =
try cmp ()
with Runtime.UncomparableDurations ->
| Runtime.UncomparableDurations ->
Errors.raise_multispanned_error (get_binop_args_pos args)
"Cannot compare together durations that cannot be converted to a \
precise number of days"
in
match op, List.map Marked.unmark args with
| Ternop Fold, [_f; _init; EArray es] ->
Marked.unmark
(List.fold_left
(fun acc e' ->
evaluate_expr ctx
(Marked.same_mark_as
(EApp { f = List.nth args 0; args = [acc; e'] })
e'))
(List.nth args 1) es)
| Binop And, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 && b2))
| Binop Or, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 || b2))
| Binop Xor, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 <> b2))
| Binop (Add KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 +! i2))
| Binop (Sub KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 -! i2))
| Binop (Mult KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LInt Runtime.(i1 *! i2))
| Binop (Div KInt), [ELit (LInt i1); ELit (LInt i2)] ->
apply_div_or_raise_err (fun _ -> ELit (LInt Runtime.(i1 /! i2)))
| Binop (Add KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 +& i2))
| Binop (Sub KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 -& i2))
| Binop (Mult KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LRat Runtime.(i1 *& i2))
| Binop (Div KRat), [ELit (LRat i1); ELit (LRat i2)] ->
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(i1 /& i2)))
| Binop (Add KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LMoney Runtime.(m1 +$ m2))
| Binop (Sub KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LMoney Runtime.(m1 -$ m2))
| Binop (Mult KMoney), [ELit (LMoney m1); ELit (LRat m2)] ->
ELit (LMoney Runtime.(m1 *$ m2))
| Binop (Div KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
apply_div_or_raise_err (fun _ -> ELit (LRat Runtime.(m1 /$ m2)))
| Binop (Add KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LDuration Runtime.(d1 +^ d2))
| Binop (Sub KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LDuration Runtime.(d1 -^ d2))
| Binop (Sub KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LDuration Runtime.(d1 -@ d2))
| Binop (Add KDate), [ELit (LDate d1); ELit (LDuration d2)] ->
ELit (LDate Runtime.(d1 +@ d2))
| Binop (Mult KDuration), [ELit (LDuration d1); ELit (LInt i1)] ->
ELit (LDuration Runtime.(d1 *^ i1))
| Binop (Lt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 <! i2))
| Binop (Lte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 <=! i2))
| Binop (Gt KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 >! i2))
| Binop (Gte KInt), [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 >=! i2))
| Binop (Lt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 <& i2))
| Binop (Lte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 <=& i2))
| Binop (Gt KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 >& i2))
| Binop (Gte KRat), [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 >=& i2))
| Binop (Lt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 <$ m2))
| Binop (Lte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 <=$ m2))
| Binop (Gt KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 >$ m2))
| Binop (Gte KMoney), [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 >=$ m2))
| Binop (Lt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <^ d2))) args
| Binop (Lte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 <=^ d2))) args
| Binop (Gt KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >^ d2))) args
| Binop (Gte KDuration), [ELit (LDuration d1); ELit (LDuration d2)] ->
apply_cmp_or_raise_err (fun _ -> ELit (LBool Runtime.(d1 >=^ d2))) args
| Binop (Lt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 <@ d2))
| Binop (Lte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 <=@ d2))
| Binop (Gt KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 >@ d2))
| Binop (Gte KDate), [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 >=@ d2))
| Binop Eq, [ELit LUnit; ELit LUnit] -> ELit (LBool true)
| Binop Eq, [ELit (LDuration d1); ELit (LDuration d2)] ->
ELit (LBool Runtime.(d1 =^ d2))
| Binop Eq, [ELit (LDate d1); ELit (LDate d2)] ->
ELit (LBool Runtime.(d1 =@ d2))
| Binop Eq, [ELit (LMoney m1); ELit (LMoney m2)] ->
ELit (LBool Runtime.(m1 =$ m2))
| Binop Eq, [ELit (LRat i1); ELit (LRat i2)] ->
ELit (LBool Runtime.(i1 =& i2))
| Binop Eq, [ELit (LInt i1); ELit (LInt i2)] ->
ELit (LBool Runtime.(i1 =! i2))
| Binop Eq, [ELit (LBool b1); ELit (LBool b2)] -> ELit (LBool (b1 = b2))
| Binop Eq, [EArray es1; EArray es2] ->
ELit
(LBool
(try
List.for_all2
(fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2
with Invalid_argument _ -> false))
| ( Binop Eq,
[EStruct { fields = es1; name = s1 }; EStruct { fields = es2; name = s2 }]
) ->
ELit
(LBool
(StructName.equal s1 s2
&& StructField.Map.equal
(fun e1 e2 ->
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *))
es1 es2))
| ( Binop Eq,
[
EInj { e = e1; cons = i1; name = en1 };
EInj { e = e2; cons = i2; name = en2 };
] ) ->
ELit
(LBool
(try
EnumName.equal en1 en2
&& EnumConstructor.equal i1 i2
&&
match evaluate_operator ctx op pos [e1; e2] with
| ELit (LBool b) -> b
| _ -> assert false
(* should not happen *)
with Invalid_argument _ -> false))
| Binop Eq, [_; _] ->
ELit (LBool false) (* comparing anything else return false *)
| Binop Neq, [_; _] -> (
match evaluate_operator ctx (Binop Eq) pos args with
| ELit (LBool b) -> ELit (LBool (not b))
| _ -> assert false (*should not happen *))
| Binop Concat, [EArray es1; EArray es2] -> EArray (es1 @ es2)
| Binop Map, [_; EArray es] ->
EArray
(List.map
(fun e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f = List.hd args; args = [e'] }) e'))
es)
| Binop Filter, [_; EArray es] ->
EArray
(List.filter
(fun e' ->
match
evaluate_expr ctx
(Marked.same_mark_as (EApp { f = List.hd args; args = [e'] }) e')
with
| ELit (LBool b), _ -> b
| _ ->
Errors.raise_spanned_error
(Expr.pos (List.nth args 0))
"This predicate evaluated to something else than a boolean \
(should not happen if the term was well-typed)")
es)
| Binop _, ([ELit LEmptyError; _] | [_; ELit LEmptyError]) -> ELit LEmptyError
| Unop (Minus KInt), [ELit (LInt i)] ->
ELit (LInt Runtime.(integer_of_int 0 -! i))
| Unop (Minus KRat), [ELit (LRat i)] ->
ELit (LRat Runtime.(decimal_of_string "0" -& i))
| Unop (Minus KMoney), [ELit (LMoney i)] ->
ELit (LMoney Runtime.(money_of_units_int 0 -$ i))
| Unop (Minus KDuration), [ELit (LDuration i)] ->
ELit (LDuration Runtime.(~-^i))
| Unop Not, [ELit (LBool b)] -> ELit (LBool (not b))
| Unop Length, [EArray es] ->
ELit (LInt (Runtime.integer_of_int (List.length es)))
| Unop GetDay, [ELit (LDate d)] ->
ELit (LInt Runtime.(day_of_month_of_date d))
| Unop GetMonth, [ELit (LDate d)] ->
ELit (LInt Runtime.(month_number_of_date d))
| Unop GetYear, [ELit (LDate d)] -> ELit (LInt Runtime.(year_of_date d))
| Unop FirstDayOfMonth, [ELit (LDate d)] ->
ELit (LDate Runtime.(first_day_of_month d))
| Unop LastDayOfMonth, [ELit (LDate d)] ->
ELit (LDate Runtime.(first_day_of_month d))
| Unop IntToRat, [ELit (LInt i)] -> ELit (LRat Runtime.(decimal_of_integer i))
| Unop MoneyToRat, [ELit (LMoney i)] ->
ELit (LRat Runtime.(decimal_of_money i))
| Unop RatToMoney, [ELit (LRat i)] ->
ELit (LMoney Runtime.(money_of_decimal i))
| Unop RoundMoney, [ELit (LMoney m)] -> ELit (LMoney Runtime.(money_round m))
| Unop RoundDecimal, [ELit (LRat m)] -> ELit (LRat Runtime.(decimal_round m))
| Unop (Log (entry, infos)), [e'] ->
if !Cli.trace_flag then (
match entry with
| VarDef _ ->
(* TODO: this usage of Format is broken, Formatting requires that all is
formatted in one pass, without going through intermediate "%s" *)
Cli.log_format "%*s%a %a: %s" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos
(match e' with
| EAbs _ -> Cli.with_style [ANSITerminal.green] "<function>"
| _ ->
let expr_str =
Format.asprintf "%a" (Expr.format ctx ~debug:false) (List.hd args)
in
let expr_str =
Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\n\\s*")
~subst:(fun _ -> " ")
expr_str
in
Cli.with_style [ANSITerminal.green] "%s" expr_str)
| PosRecordIfTrueBool -> (
match pos <> Pos.no_pos, e' with
| true, ELit (LBool true) ->
Cli.log_format "%*s%a%s:\n%s" (!log_indent * 2) "" Print.log_entry
entry
(Cli.with_style [ANSITerminal.green] "Definition applied")
(Cli.add_prefix_to_each_line (Pos.retrieve_loc_text pos) (fun _ ->
Format.asprintf "%*s" (!log_indent * 2) ""))
| _ -> ())
| BeginCall ->
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos;
log_indent := !log_indent + 1
| EndCall ->
log_indent := !log_indent - 1;
Cli.log_format "%*s%a %a" (!log_indent * 2) "" Print.log_entry entry
Print.uid_list infos)
else ();
e'
| Unop _, [ELit LEmptyError] -> ELit LEmptyError
| _ ->
let err () =
Errors.raise_multispanned_error
([Some "Operator:", pos]
@ List.mapi
@ -313,6 +149,153 @@ let rec evaluate_operator
args)
"Operator applied to the wrong arguments\n\
(should not happen if the term was well-typed)"
in
let open Runtime.Oper in
if List.exists (function ELit LEmptyError, _ -> true | _ -> false) args then
ELit LEmptyError
else
Operator.kind_dispatch op
~polymorphic:(fun op ->
match op, args with
| Length, [(EArray es, _)] ->
ELit (LInt (Runtime.integer_of_int (List.length es)))
| Log (entry, infos), [e'] ->
print_log ctx entry infos pos e';
Marked.unmark e'
| Eq, [(e1, _); (e2, _)] -> ELit (LBool (handle_eq ctx pos e1 e2))
| Map, [f; (EArray es, _)] ->
EArray
(List.map
(fun e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [e'] }) e'))
es)
| Concat, [(EArray es1, _); (EArray es2, _)] -> EArray (es1 @ es2)
| Filter, [f; (EArray es, _)] ->
EArray
(List.filter
(fun e' ->
match
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [e'] }) e')
with
| ELit (LBool b), _ -> b
| _ ->
Errors.raise_spanned_error
(Expr.pos (List.nth args 0))
"This predicate evaluated to something else than a \
boolean (should not happen if the term was well-typed)")
es)
| Fold, [f; init; (EArray es, _)] ->
Marked.unmark
(List.fold_left
(fun acc e' ->
evaluate_expr ctx
(Marked.same_mark_as (EApp { f; args = [acc; e'] }) e'))
init es)
| (Length | Log _ | Eq | Map | Concat | Filter | Fold), _ -> err ())
~monomorphic:(fun op ->
let rlit =
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
| Not, [LBool b] -> LBool (o_not b)
| IntToRat, [LInt i] -> LRat (o_intToRat i)
| MoneyToRat, [LMoney i] -> LRat (o_moneyToRat i)
| RatToMoney, [LRat i] -> LMoney (o_ratToMoney i)
| GetDay, [LDate d] -> LInt (o_getDay d)
| GetMonth, [LDate d] -> LInt (o_getMonth d)
| GetYear, [LDate d] -> LInt (o_getYear d)
| FirstDayOfMonth, [LDate d] -> LDate (o_firstDayOfMonth d)
| LastDayOfMonth, [LDate d] -> LDate (o_lastDayOfMonth d)
| RoundMoney, [LMoney m] -> LMoney (o_roundMoney m)
| RoundDecimal, [LRat m] -> LRat (o_roundDecimal m)
| And, [LBool b1; LBool b2] -> LBool (o_and b1 b2)
| Or, [LBool b1; LBool b2] -> LBool (o_or b1 b2)
| Xor, [LBool b1; LBool b2] -> LBool (o_xor b1 b2)
| ( ( Not | IntToRat | MoneyToRat | RatToMoney | GetDay | GetMonth
| GetYear | FirstDayOfMonth | LastDayOfMonth | RoundMoney
| RoundDecimal | And | Or | Xor ),
_ ) ->
err ()
in
ELit rlit)
~resolved:(fun op ->
let rlit =
match op, List.map (function ELit l, _ -> l | _ -> err ()) args with
| Minus_int, [LInt x] -> LInt (o_minus_int x)
| Minus_rat, [LRat x] -> LRat (o_minus_rat x)
| Minus_mon, [LMoney x] -> LMoney (o_minus_mon x)
| Minus_dur, [LDuration x] -> LDuration (o_minus_dur x)
| Add_int_int, [LInt x; LInt y] -> LInt (o_add_int_int x y)
| Add_rat_rat, [LRat x; LRat y] -> LRat (o_add_rat_rat x y)
| Add_mon_mon, [LMoney x; LMoney y] -> LMoney (o_add_mon_mon x y)
| Add_dat_dur, [LDate x; LDuration y] -> LDate (o_add_dat_dur x y)
| Add_dur_dur, [LDuration x; LDuration y] ->
LDuration (o_add_dur_dur x y)
| Sub_int_int, [LInt x; LInt y] -> LInt (o_sub_int_int x y)
| Sub_rat_rat, [LRat x; LRat y] -> LRat (o_sub_rat_rat x y)
| Sub_mon_mon, [LMoney x; LMoney y] -> LMoney (o_sub_mon_mon x y)
| Sub_dat_dat, [LDate x; LDate y] -> LDuration (o_sub_dat_dat x y)
| Sub_dat_dur, [LDate x; LDuration y] -> LDate (o_sub_dat_dur x y)
| Sub_dur_dur, [LDuration x; LDuration y] ->
LDuration (o_sub_dur_dur x y)
| Mult_int_int, [LInt x; LInt y] -> LInt (o_mult_int_int x y)
| Mult_rat_rat, [LRat x; LRat y] -> LRat (o_mult_rat_rat x y)
| Mult_mon_rat, [LMoney x; LRat y] -> LMoney (o_mult_mon_rat x y)
| Mult_dur_int, [LDuration x; LInt y] ->
LDuration (o_mult_dur_int x y)
| Div_int_int, [LInt x; LInt y] -> LInt (protect o_div_int_int x y)
| Div_rat_rat, [LRat x; LRat y] -> LRat (protect o_div_rat_rat x y)
| Div_mon_mon, [LMoney x; LMoney y] ->
LRat (protect o_div_mon_mon x y)
| Div_mon_rat, [LMoney x; LRat y] ->
LMoney (protect o_div_mon_rat x y)
| Lt_int_int, [LInt x; LInt y] -> LBool (o_lt_int_int x y)
| Lt_rat_rat, [LRat x; LRat y] -> LBool (o_lt_rat_rat x y)
| Lt_mon_mon, [LMoney x; LMoney y] -> LBool (o_lt_mon_mon x y)
| Lt_dat_dat, [LDate x; LDate y] -> LBool (o_lt_dat_dat x y)
| Lt_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_lt_dur_dur x y)
| Lte_int_int, [LInt x; LInt y] -> LBool (o_lte_int_int x y)
| Lte_rat_rat, [LRat x; LRat y] -> LBool (o_lte_rat_rat x y)
| Lte_mon_mon, [LMoney x; LMoney y] -> LBool (o_lte_mon_mon x y)
| Lte_dat_dat, [LDate x; LDate y] -> LBool (o_lte_dat_dat x y)
| Lte_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_lte_dur_dur x y)
| Gt_int_int, [LInt x; LInt y] -> LBool (o_gt_int_int x y)
| Gt_rat_rat, [LRat x; LRat y] -> LBool (o_gt_rat_rat x y)
| Gt_mon_mon, [LMoney x; LMoney y] -> LBool (o_gt_mon_mon x y)
| Gt_dat_dat, [LDate x; LDate y] -> LBool (o_gt_dat_dat x y)
| Gt_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_gt_dur_dur x y)
| Gte_int_int, [LInt x; LInt y] -> LBool (o_gte_int_int x y)
| Gte_rat_rat, [LRat x; LRat y] -> LBool (o_gte_rat_rat x y)
| Gte_mon_mon, [LMoney x; LMoney y] -> LBool (o_gte_mon_mon x y)
| Gte_dat_dat, [LDate x; LDate y] -> LBool (o_gte_dat_dat x y)
| Gte_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_gte_dur_dur x y)
| Eq_int_int, [LInt x; LInt y] -> LBool (o_eq_int_int x y)
| Eq_rat_rat, [LRat x; LRat y] -> LBool (o_eq_rat_rat x y)
| Eq_mon_mon, [LMoney x; LMoney y] -> LBool (o_eq_mon_mon x y)
| Eq_dat_dat, [LDate x; LDate y] -> LBool (o_eq_dat_dat x y)
| Eq_dur_dur, [LDuration x; LDuration y] ->
LBool (protect o_eq_dur_dur x y)
| ( ( Minus_int | Minus_rat | Minus_mon | Minus_dur | Add_int_int
| Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat
| Sub_dat_dur | Sub_dur_dur | Mult_int_int | Mult_rat_rat
| Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat
| Lte_mon_mon | Lte_dat_dat | Lte_dur_dur | Gt_int_int
| Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur | Gte_int_int
| Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur
),
_ ) ->
err ()
in
ELit rlit)
~overloaded:(fun _ -> assert false)
and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
match Marked.unmark e with
@ -333,7 +316,7 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
"wrong function call, expected %d arguments, got %d"
(Bindlib.mbinder_arity binder)
(List.length args)
| EOp op ->
| EOp { op; _ } ->
Marked.same_mark_as (evaluate_operator ctx op (Expr.pos e) args) e
| ELit LEmptyError -> Marked.same_mark_as (ELit LEmptyError) e
| _ ->
@ -449,31 +432,41 @@ and evaluate_expr (ctx : decl_ctx) (e : 'm Ast.expr) : 'm Ast.expr =
| EErrorOnEmpty
( EApp
{
f = EOp (Binop op), _;
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
},
_ )
_ ) ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| EApp
{
f = EOp (Unop (Log _)), _;
f = EOp { op = Log _; _ }, _;
args =
[
( EApp
{
f = EOp (Binop op), _;
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
},
_ );
];
}
} ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| EApp
{
f = EOp (Binop op), _;
f = EOp { op; _ }, _;
args = [((ELit _, _) as e1); ((ELit _, _) as e2)];
} ->
Errors.raise_spanned_error (Expr.pos e') "Assertion failed: %a %a %a"
(Expr.format ctx ~debug:false)
e1 Print.binop op
e1 Print.operator op
(Expr.format ctx ~debug:false)
e2
| _ ->

View File

@ -37,8 +37,12 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
| EApp
{
f =
( EOp (Unop Not), _
| ( EApp { f = EOp (Unop (Log _)), _; args = [(EOp (Unop Not), _)] },
( EOp { op = Not; _ }, _
| ( EApp
{
f = EOp { op = Log _; _ }, _;
args = [(EOp { op = Not; _ }, _)];
},
_ ) ) as op;
args = [e1];
} -> (
@ -50,8 +54,12 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
| EApp
{
f =
( EOp (Binop Or), _
| ( EApp { f = EOp (Unop (Log _)), _; args = [(EOp (Binop Or), _)] },
( EOp { op = Or; _ }, _
| ( EApp
{
f = EOp { op = Log _; _ }, _;
args = [(EOp { op = Or; _ }, _)];
},
_ ) ) as op;
args = [e1; e2];
} -> (
@ -65,8 +73,12 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
| EApp
{
f =
( EOp (Binop And), _
| ( EApp { f = EOp (Unop (Log _)), _; args = [(EOp (Binop And), _)] },
( EOp { op = And; _ }, _
| ( EApp
{
f = EOp { op = Log _; _ }, _;
args = [(EOp { op = And; _ }, _)];
},
_ ) ) as op;
args = [e1; e2];
} -> (
@ -111,15 +123,17 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
| ( [],
( ( ELit (LBool true)
| EApp
{ f = EOp (Unop (Log _)), _; args = [(ELit (LBool true), _)] }
),
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool true), _)];
} ),
_ ) ) ->
Marked.unmark cons
| ( [],
( ( ELit (LBool false)
| EApp
{
f = EOp (Unop (Log _)), _;
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool false), _)];
} ),
_ ) ) ->
@ -139,7 +153,10 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
cond =
( ELit (LBool true), _
| ( EApp
{ f = EOp (Unop (Log _)), _; args = [(ELit (LBool true), _)] },
{
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool true), _)];
},
_ ) );
etrue;
_;
@ -151,7 +168,7 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
( ( ELit (LBool false)
| EApp
{
f = EOp (Unop (Log _)), _;
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool false), _)];
} ),
_ );
@ -166,7 +183,7 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
( ( ELit (LBool btrue)
| EApp
{
f = EOp (Unop (Log _)), _;
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool btrue), _)];
} ),
_ );
@ -174,14 +191,18 @@ let rec partial_evaluation (ctx : partial_evaluation_ctx) (e : 'm expr) :
( ( ELit (LBool bfalse)
| EApp
{
f = EOp (Unop (Log _)), _;
f = EOp { op = Log _; _ }, _;
args = [(ELit (LBool bfalse), _)];
} ),
_ );
} ->
if btrue && not bfalse then Marked.unmark cond
else if (not btrue) && bfalse then
EApp { f = EOp (Unop Not), mark; args = [cond] }
EApp
{
f = EOp { op = Not; tys = [TLit TBool, Expr.mark_pos mark] }, mark;
args = [cond];
}
(* note: this last call eliminates the condition & might skip log calls
as well *)
else (* btrue = bfalse *) ELit (LBool btrue)

View File

@ -125,7 +125,7 @@ module Rule = struct
Expr.compare c1 c2
| n -> n)
| Some (v1, t1), Some (v2, t2) -> (
match Shared_ast.Expr.compare_typ t1 t2 with
match Type.compare t1 t2 with
| 0 -> (
let open Bindlib in
let b1 = unbox (bind_var v1 (Expr.Box.lift r1.rule_just)) in

View File

@ -16,6 +16,7 @@
the License. *)
open Catala_utils
module S = Surface.Ast
module SurfacePrint = Surface.Print
open Shared_ast
module Runtime = Runtime_ocaml.Runtime
@ -27,33 +28,87 @@ module Runtime = Runtime_ocaml.Runtime
(** {1 Translating expressions} *)
let translate_op_kind (k : Surface.Ast.op_kind) : desugared op_kind =
match k with
| Surface.Ast.KInt -> KInt
| Surface.Ast.KDec -> KRat
| Surface.Ast.KMoney -> KMoney
| Surface.Ast.KDate -> KDate
| Surface.Ast.KDuration -> KDuration
(* Resolves the operator kinds into the expected operator operand types *)
let translate_binop (op : Surface.Ast.binop) : desugared binop =
let translate_binop : Surface.Ast.binop -> Pos.t -> Ast.expr boxed =
fun op pos ->
let e op tys =
Expr.eop op (List.map (Marked.mark pos) tys) (Untyped { pos })
in
match op with
| And -> And
| Or -> Or
| Xor -> Xor
| Add l -> Add (translate_op_kind l)
| Sub l -> Sub (translate_op_kind l)
| Mult l -> Mult (translate_op_kind l)
| Div l -> Div (translate_op_kind l)
| Lt l -> Lt (translate_op_kind l)
| Lte l -> Lte (translate_op_kind l)
| Gt l -> Gt (translate_op_kind l)
| Gte l -> Gte (translate_op_kind l)
| Eq -> Eq
| Neq -> Neq
| Concat -> Concat
| S.And -> e And [TLit TBool; TLit TBool]
| S.Or -> e Or [TLit TBool; TLit TBool]
| S.Xor -> e Xor [TLit TBool; TLit TBool]
| S.Add k ->
e Add
(match k with
| S.KPoly -> [TAny; TAny]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
| S.KDate -> [TLit TDate; TLit TDuration]
| S.KDuration -> [TLit TDuration; TLit TDuration])
| S.Sub k ->
e Sub
(match k with
| S.KPoly -> [TAny; TAny]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
| S.KDate -> [TLit TDate; TLit TDate]
| S.KDuration -> [TLit TDuration; TLit TDuration])
| S.Mult k ->
e Mult
(match k with
| S.KPoly -> [TAny; TAny]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TRat]
| S.KDate -> Errors.raise_spanned_error pos "Invalid operator"
| S.KDuration -> [TLit TDuration; TLit TInt])
| S.Div k ->
e Div
(match k with
| S.KPoly -> [TAny; TAny]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
| S.KDate -> Errors.raise_spanned_error pos "Invalid operator"
| S.KDuration -> [TLit TDuration; TLit TDuration])
| S.Lt k | S.Lte k | S.Gt k | S.Gte k ->
e
(match op with
| S.Lt _ -> Lt
| S.Lte _ -> Lte
| S.Gt _ -> Gt
| S.Gte _ -> Gte
| _ -> assert false)
(match k with
| S.KPoly -> [TAny; TAny]
| S.KInt -> [TLit TInt; TLit TInt]
| S.KDec -> [TLit TRat; TLit TRat]
| S.KMoney -> [TLit TMoney; TLit TMoney]
| S.KDate -> [TLit TDate; TLit TDate]
| S.KDuration -> [TLit TDuration; TLit TDuration])
| S.Eq ->
e Eq [TAny; TAny]
(* This is a truly polymorphic operator, not an overload *)
| S.Neq -> assert false (* desugared already *)
| S.Concat -> e Concat [TArray (TAny, pos); TArray (TAny, pos)]
let translate_unop (op : Surface.Ast.unop) : desugared unop =
match op with Not -> Not | Minus l -> Minus (translate_op_kind l)
let translate_unop (op : Surface.Ast.unop) pos : Ast.expr boxed =
let e op ty = Expr.eop op [Marked.mark pos ty] (Untyped { pos }) in
match op with
| S.Not -> e Not (TLit TBool)
| S.Minus k ->
e Minus
(match k with
| S.KPoly -> TAny
| S.KInt -> TLit TInt
| S.KDec -> TLit TRat
| S.KMoney -> TLit TMoney
| S.KDate -> Errors.raise_spanned_error pos "Invalid operator"
| S.KDuration -> TLit TDuration)
let disambiguate_constructor
(ctxt : Name_resolution.context)
@ -102,6 +157,21 @@ let disambiguate_constructor
Errors.raise_spanned_error (Marked.get_mark enum)
"Enum %s has not been defined before" (Marked.unmark enum))
let int100 = Runtime.integer_of_int 100
let rat100 = Runtime.decimal_of_integer int100
let aggregate_typ pos = function
| None -> TAny
| Some S.Integer -> TLit TInt
| Some S.Decimal -> TLit TRat
| Some S.Money -> TLit TMoney
| Some S.Duration -> TLit TDuration
| Some S.Date -> TLit TDate
| Some pred_typ ->
Errors.raise_spanned_error pos
"It is impossible to compute this aggregation of two values of type %a"
SurfacePrint.format_primitive_typ pred_typ
(** Usage: [translate_expr scope ctxt naked_expr]
Translates [expr] into its desugared equivalent. [scope] is used to
@ -148,30 +218,36 @@ let rec translate_expr
| IfThenElse (e_if, e_then, e_else) ->
Expr.eifthenelse (rec_helper e_if) (rec_helper e_then) (rec_helper e_else)
emark
| Binop ((S.Neq, posn), e1, e2) ->
(* Neq is just sugar *)
rec_helper (Unop ((S.Not, posn), (Binop ((S.Eq, posn), e1, e2), posn)), pos)
| Binop ((op, pos), e1, e2) ->
let op_term = Expr.eop (Binop (translate_binop op)) (Untyped { pos }) in
let op_term = translate_binop op pos in
Expr.eapp op_term [rec_helper e1; rec_helper e2] emark
| Unop ((op, pos), e) ->
let op_term = Expr.eop (Unop (translate_unop op)) (Untyped { pos }) in
let op_term = translate_unop op pos in
Expr.eapp op_term [rec_helper e] emark
| Literal l ->
let lit =
match l with
| LNumber ((Int i, _), None) -> LInt (Runtime.integer_of_string i)
| LNumber ((Int i, _), Some (Percent, _)) ->
LRat Runtime.(decimal_of_string i /& decimal_of_string "100")
LRat Runtime.(Oper.o_div_rat_rat (decimal_of_string i) rat100)
| LNumber ((Dec (i, f), _), None) ->
LRat Runtime.(decimal_of_string (i ^ "." ^ f))
| LNumber ((Dec (i, f), _), Some (Percent, _)) ->
LRat
Runtime.(decimal_of_string (i ^ "." ^ f) /& decimal_of_string "100")
Runtime.(Oper.o_div_rat_rat (decimal_of_string (i ^ "." ^ f)) rat100)
| LBool b -> LBool b
| LMoneyAmount i ->
LMoney
Runtime.(
money_of_cents_integer
((integer_of_string i.money_amount_units *! integer_of_int 100)
+! integer_of_string i.money_amount_cents))
(Oper.o_add_int_int
(Oper.o_mult_int_int
(integer_of_string i.money_amount_units)
int100)
(integer_of_string i.money_amount_cents)))
| LNumber ((Int i, _), Some (Year, _)) ->
LDuration (Runtime.duration_of_numbers (int_of_string i) 0 0)
| LNumber ((Int i, _), Some (Month, _)) ->
@ -468,9 +544,10 @@ let rec translate_expr
Expr.eapp
(Expr.eop
(match op' with
| Surface.Ast.Map -> Binop Map
| Surface.Ast.Filter -> Binop Filter
| Surface.Ast.Map -> Map
| Surface.Ast.Filter -> Filter
| _ -> assert false (* should not happen *))
[TAny, pos; TAny, pos]
emark)
[f_pred; collection] emark
| CollectionOp
@ -485,20 +562,8 @@ let rec translate_expr
let ctxt, param =
Name_resolution.add_def_local_var ctxt (Marked.unmark param')
in
let op_kind =
match pred_typ with
| Surface.Ast.Integer -> KInt
| Surface.Ast.Decimal -> KRat
| Surface.Ast.Money -> KMoney
| Surface.Ast.Duration -> KDuration
| Surface.Ast.Date -> KDate
| _ ->
Errors.raise_spanned_error pos
"It is impossible to compute the arg-%s of two values of type %a"
(if max_or_min then "max" else "min")
SurfacePrint.format_primitive_typ pred_typ
in
let cmp_op = if max_or_min then Gt op_kind else Lt op_kind in
let op_ty = aggregate_typ pos pred_typ in
let cmp_op = if max_or_min then Op.Gt else Op.Lt in
let f_pred =
Expr.make_abs [| param |]
(translate_expr scope inside_definition_of ctxt predicate)
@ -512,7 +577,9 @@ let rec translate_expr
let fold_body =
Expr.eifthenelse
(Expr.eapp
(Expr.eop (Binop cmp_op) (Untyped { pos = pos_op' }))
(Expr.eop cmp_op
[op_ty, pos_op'; op_ty, pos_op']
(Untyped { pos = pos_op' }))
[
Expr.eapp f_pred [acc_var_e] emark;
Expr.eapp f_pred [item_var_e] emark;
@ -523,7 +590,9 @@ let rec translate_expr
let fold_f =
Expr.make_abs [| acc_var; item_var |] fold_body [TAny, pos; TAny, pos] pos
in
Expr.eapp (Expr.eop (Ternop Fold) emark) [fold_f; init; collection] emark
Expr.eapp
(Expr.eop Fold [TAny, pos_op'; TAny, pos_op'; TAny, pos_op'] emark)
[fold_f; init; collection] emark
| CollectionOp (op', param', collection, predicate) ->
let ctxt, param =
Name_resolution.add_def_local_var ctxt (Marked.unmark param')
@ -561,20 +630,22 @@ let rec translate_expr
Expr.make_var acc_var (Untyped { pos = Marked.get_mark param' })
in
let f_body =
let make_body (op : desugared binop) =
Expr.eapp (Expr.eop (Binop op) mark)
let make_body op =
Expr.eapp (translate_binop op pos)
[acc; translate_expr scope inside_definition_of ctxt predicate]
emark
in
let make_extr_body (cmp_op : desugared binop) (t : typ) =
let make_extr_body cmp_op typ =
let tmp_var = Var.make "tmp" in
let tmp =
Expr.make_var tmp_var (Untyped { pos = Marked.get_mark param' })
in
Expr.make_let_in tmp_var t
Expr.make_let_in tmp_var (TAny, pos)
(translate_expr scope inside_definition_of ctxt predicate)
(Expr.eifthenelse
(Expr.eapp (Expr.eop (Binop cmp_op) mark) [acc; tmp] emark)
(Expr.eapp
(Expr.eop cmp_op [typ, pos; typ, pos] mark)
[acc; tmp] emark)
acc tmp emark)
pos
in
@ -587,7 +658,7 @@ let rec translate_expr
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Integer) ->
make_body (Add KInt)
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Decimal) ->
make_body (Add KRat)
make_body (Add KDec)
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Money) ->
make_body (Add KMoney)
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Duration) ->
@ -596,20 +667,8 @@ let rec translate_expr
assert false (* should not happen *)
| Surface.Ast.Aggregate (Surface.Ast.AggregateExtremum (max_or_min, t, _))
->
let op_kind, typ =
match t with
| Surface.Ast.Integer -> KInt, (TLit TInt, pos)
| Surface.Ast.Decimal -> KRat, (TLit TRat, pos)
| Surface.Ast.Money -> KMoney, (TLit TMoney, pos)
| Surface.Ast.Duration -> KDuration, (TLit TDuration, pos)
| Surface.Ast.Date -> KDate, (TLit TDate, pos)
| _ ->
Errors.raise_spanned_error pos
"It is impossible to compute the %s of two values of type %a"
(if max_or_min then "max" else "min")
SurfacePrint.format_primitive_typ t
in
let cmp_op = if max_or_min then Gt op_kind else Lt op_kind in
let typ = aggregate_typ pos t in
let cmp_op = if max_or_min then Op.Gt else Op.Lt in
make_extr_body cmp_op typ
| Surface.Ast.Aggregate Surface.Ast.AggregateCount ->
let predicate =
@ -617,7 +676,7 @@ let rec translate_expr
in
Expr.eifthenelse predicate
(Expr.eapp
(Expr.eop (Binop (Add KInt)) mark)
(Expr.eop Add [TLit TInt, pos; TLit TInt, pos] mark)
[
acc;
Expr.elit
@ -628,11 +687,11 @@ let rec translate_expr
acc emark
in
let f =
let make_f (t : typ_lit) =
let make_f t =
Expr.eabs
(Expr.bind [| acc_var; param |] f_body)
[
TLit t, Marked.get_mark op';
t, Marked.get_mark op';
TAny, pos
(* we put any here because the type of the elements of the arrays is
not always the type of the accumulator; for instance in
@ -644,30 +703,17 @@ let rec translate_expr
| Surface.Ast.Map | Surface.Ast.Filter
| Surface.Ast.Aggregate (Surface.Ast.AggregateArgExtremum _) ->
assert false (* should not happen *)
| Surface.Ast.Exists -> make_f TBool
| Surface.Ast.Forall -> make_f TBool
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Integer)
| Surface.Ast.Aggregate
(Surface.Ast.AggregateExtremum (_, Surface.Ast.Integer, _)) ->
make_f TInt
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Decimal)
| Surface.Ast.Aggregate
(Surface.Ast.AggregateExtremum (_, Surface.Ast.Decimal, _)) ->
make_f TRat
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Money)
| Surface.Ast.Aggregate
(Surface.Ast.AggregateExtremum (_, Surface.Ast.Money, _)) ->
make_f TMoney
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum Surface.Ast.Duration)
| Surface.Ast.Aggregate
(Surface.Ast.AggregateExtremum (_, Surface.Ast.Duration, _)) ->
make_f TDuration
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum _)
| Surface.Ast.Aggregate (Surface.Ast.AggregateExtremum _) ->
assert false (* should not happen *)
| Surface.Ast.Aggregate Surface.Ast.AggregateCount -> make_f TInt
| Surface.Ast.Exists -> make_f (TLit TBool)
| Surface.Ast.Forall -> make_f (TLit TBool)
| Surface.Ast.Aggregate (Surface.Ast.AggregateSum k) ->
make_f (aggregate_typ pos (Some k))
| Surface.Ast.Aggregate (Surface.Ast.AggregateExtremum (_, k, _)) ->
make_f (aggregate_typ pos k)
| Surface.Ast.Aggregate Surface.Ast.AggregateCount -> make_f (TLit TInt)
in
Expr.eapp (Expr.eop (Ternop Fold) emark) [f; init; collection] emark
Expr.eapp
(Expr.eop Fold [TAny, pos; TAny, pos; TAny, pos] mark)
[f; init; collection] emark
| MemCollection (member, collection) ->
let param_var = Var.make "collection_member" in
let param = Expr.make_var param_var emark in
@ -678,8 +724,13 @@ let rec translate_expr
let f_body =
let member = translate_expr scope inside_definition_of ctxt member in
Expr.eapp
(Expr.eop (Binop Or) emark)
[Expr.eapp (Expr.eop (Binop Eq) emark) [member; param] emark; acc]
(Expr.eop Or [TLit TBool, pos; TLit TBool, pos] emark)
[
Expr.eapp
(Expr.eop Eq [TAny, pos; TAny, pos] emark)
[member; param] emark;
acc;
]
emark
in
let f =
@ -688,18 +739,20 @@ let rec translate_expr
[TLit TBool, pos; TAny, pos]
emark
in
Expr.eapp (Expr.eop (Ternop Fold) emark) [f; init; collection] emark
| Builtin IntToDec -> Expr.eop (Unop IntToRat) emark
| Builtin MoneyToDec -> Expr.eop (Unop MoneyToRat) emark
| Builtin DecToMoney -> Expr.eop (Unop RatToMoney) emark
| Builtin Cardinal -> Expr.eop (Unop Length) emark
| Builtin GetDay -> Expr.eop (Unop GetDay) emark
| Builtin GetMonth -> Expr.eop (Unop GetMonth) emark
| Builtin GetYear -> Expr.eop (Unop GetYear) emark
| Builtin FirstDayOfMonth -> Expr.eop (Unop FirstDayOfMonth) emark
| Builtin LastDayOfMonth -> Expr.eop (Unop LastDayOfMonth) emark
| Builtin RoundMoney -> Expr.eop (Unop RoundMoney) emark
| Builtin RoundDecimal -> Expr.eop (Unop RoundDecimal) emark
Expr.eapp
(Expr.eop Fold [TAny, pos; TAny, pos; TAny, pos] emark)
[f; init; collection] emark
| Builtin IntToDec -> Expr.eop IntToRat [TLit TInt, pos] emark
| Builtin MoneyToDec -> Expr.eop MoneyToRat [TLit TMoney, pos] emark
| Builtin DecToMoney -> Expr.eop RatToMoney [TLit TRat, pos] emark
| Builtin Cardinal -> Expr.eop Length [TArray (TAny, pos), pos] emark
| Builtin GetDay -> Expr.eop GetDay [TLit TDate, pos] emark
| Builtin GetMonth -> Expr.eop GetMonth [TLit TDate, pos] emark
| Builtin GetYear -> Expr.eop GetYear [TLit TDate, pos] emark
| Builtin FirstDayOfMonth -> Expr.eop FirstDayOfMonth [TLit TDate, pos] emark
| Builtin LastDayOfMonth -> Expr.eop LastDayOfMonth [TLit TDate, pos] emark
| Builtin RoundMoney -> Expr.eop RoundMoney [TLit TMoney, pos] emark
| Builtin RoundDecimal -> Expr.eop RoundDecimal [TLit TRat, pos] emark
and disambiguate_match_and_build_expression
(scope : ScopeName.t)
@ -844,7 +897,11 @@ let merge_conditions
(default_pos : Pos.t) : Ast.expr boxed =
match precond, cond with
| Some precond, Some cond ->
let op_term = Expr.eop (Binop And) (Marked.get_mark cond) in
let op_term =
Expr.eop And
[TLit TBool, default_pos; TLit TBool, default_pos]
(Marked.get_mark cond)
in
Expr.eapp op_term [precond; cond] (Marked.get_mark cond)
| Some precond, None -> Marked.unmark precond, Untyped { pos = default_pos }
| None, Some cond -> cond

View File

@ -72,7 +72,7 @@ and translate_expr (ctx : 'm ctx) (e : 'm D.expr) : 'm A.expr boxed =
l) ->
Expr.elit l m
| ELit LEmptyError -> Expr.eraise EmptyError m
| EOp op -> Expr.eop (Expr.translate_op op) m
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys m
| EIfThenElse { cond; etrue; efalse } ->
Expr.eifthenelse (translate_expr ctx cond) (translate_expr ctx etrue)
(translate_expr ctx efalse)

View File

@ -289,7 +289,7 @@ let rec translate_and_hoist (ctx : 'm ctx) (e : 'm D.expr) :
let es', hoists = es |> List.map (translate_and_hoist ctx) |> List.split in
Expr.earray es' mark, disjoint_union_maps (Expr.pos e) hoists
| EOp op -> Expr.eop (Expr.translate_op op) mark, Var.Map.empty
| EOp { op; tys } -> Expr.eop (Operator.translate op) tys mark, Var.Map.empty
and translate_expr ?(append_esome = true) (ctx : 'm ctx) (e : 'm D.expr) :
'm A.expr boxed =

View File

@ -73,10 +73,12 @@ let rec peephole_expr (e : 'm expr) : 'm expr boxed =
(fun cond etrue efalse ->
match Marked.unmark cond with
| ELit (LBool true)
| EApp { f = EOp (Unop (Log _)), _; args = [(ELit (LBool true), _)] } ->
| EApp { f = EOp { op = Log _; _ }, _; args = [(ELit (LBool true), _)] }
->
Marked.unmark etrue
| ELit (LBool false)
| EApp { f = EOp (Unop (Log _)), _; args = [(ELit (LBool false), _)] }
| EApp
{ f = EOp { op = Log _; _ }, _; args = [(ELit (LBool false), _)] }
->
Marked.unmark efalse
| _ -> EIfThenElse { cond; etrue; efalse })

View File

@ -54,36 +54,6 @@ let format_lit (fmt : Format.formatter) (l : lit Marked.pos) : unit =
let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "duration_of_numbers (%d) (%d) (%d)" years months days
let format_op_kind (fmt : Format.formatter) (k : 'a op_kind) =
Format.fprintf fmt "%s"
(match k with
| KInt -> "!"
| KRat -> "&"
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let format_binop (fmt : Format.formatter) (op : 'a binop Marked.pos) : unit =
match Marked.unmark op with
| Add k -> Format.fprintf fmt "+%a" format_op_kind k
| Sub k -> Format.fprintf fmt "-%a" format_op_kind k
| Mult k -> Format.fprintf fmt "*%a" format_op_kind k
| Div k -> Format.fprintf fmt "/%a" format_op_kind k
| And -> Format.fprintf fmt "%s" "&&"
| Or -> Format.fprintf fmt "%s" "||"
| Eq -> Format.fprintf fmt "%s" "="
| Neq | Xor -> Format.fprintf fmt "%s" "<>"
| Lt k -> Format.fprintf fmt "%s%a" "<" format_op_kind k
| Lte k -> Format.fprintf fmt "%s%a" "<=" format_op_kind k
| Gt k -> Format.fprintf fmt "%s%a" ">" format_op_kind k
| Gte k -> Format.fprintf fmt "%s%a" ">=" format_op_kind k
| Concat -> Format.fprintf fmt "@"
| Map -> Format.fprintf fmt "Array.map"
| Filter -> Format.fprintf fmt "array_filter"
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit =
match Marked.unmark op with Fold -> Format.fprintf fmt "Array.fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit =
Format.fprintf fmt "@[<hov 2>[%a]@]"
@ -103,26 +73,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids
let format_unop (fmt : Format.formatter) (op : lcalc unop Marked.pos) : unit =
match Marked.unmark op with
| Minus k -> Format.fprintf fmt "~-%a" format_op_kind k
| Not -> Format.fprintf fmt "%s" "not"
| Log (_entry, _infos) ->
Errors.raise_spanned_error (Marked.get_mark op)
"Internal error: a log operator has not been caught by the expression \
match"
| Length -> Format.fprintf fmt "%s" "array_length"
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
let avoid_keywords (s : string) : string =
match s with
(* list taken from
@ -134,7 +84,7 @@ let avoid_keywords (s : string) : string =
| "match" | "method" | "mod" | "module" | "mutable" | "new" | "nonrec"
| "object" | "of" | "open" | "or" | "private" | "rec" | "sig" | "struct"
| "then" | "to" | "true" | "try" | "type" | "val" | "virtual" | "when"
| "while" | "with" ->
| "while" | "with" | "Stdlib" | "Runtime" | "Oper" ->
s ^ "_user"
| _ -> s
@ -235,8 +185,8 @@ let format_var (fmt : Format.formatter) (v : 'm Var.t) : unit =
if
List.mem lowercase_name ["handle_default"; "handle_default_opt"]
|| String.begins_with_uppercase (Bindlib.name_of v)
then Format.fprintf fmt "%s" lowercase_name
else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name
then Format.pp_print_string fmt lowercase_name
else if lowercase_name = "_" then Format.pp_print_string fmt lowercase_name
else (
Cli.debug_print "lowercase_name: %s " lowercase_name;
Format.fprintf fmt "%s_" lowercase_name)
@ -305,7 +255,8 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
Format.fprintf fmt "let@ %a@ = %a@ in@ x"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(fun fmt i -> Format.fprintf fmt "%s" (if i = index then "x" else "_")))
(fun fmt i ->
Format.pp_print_string fmt (if i = index then "x" else "_")))
(List.init size Fun.id) format_with_parens e
| EStructAccess { e; field; name } ->
Format.fprintf fmt "%a.%a" format_with_parens e format_struct_field_name
@ -355,25 +306,19 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
(fun fmt (x, tau) ->
Format.fprintf fmt "@[<hov 2>(%a:@ %a)@]" format_var x format_typ tau))
xs_tau format_expr body
| EApp { f = EOp (Binop ((Map | Filter) as op)), _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_binop (op, Pos.no_pos)
format_with_parens arg1 format_with_parens arg2
| EApp { f = EOp (Binop op), _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
format_binop (op, Pos.no_pos) format_with_parens arg2
| EApp
{
f = EApp { f = EOp (Unop (Log (BeginCall, info))), _; args = [f] }, _;
f = EApp { f = EOp { op = Log (BeginCall, info); _ }, _; args = [f] }, _;
args = [arg];
}
when !Cli.trace_flag ->
Format.fprintf fmt "(log_begin_call@ %a@ %a)@ %a" format_uid_list info
format_with_parens f format_with_parens arg
| EApp { f = EOp (Unop (Log (VarDef tau, info))), _; args = [arg1] }
| EApp { f = EOp { op = Log (VarDef tau, info); _ }, _; args = [arg1] }
when !Cli.trace_flag ->
Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list
info typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1
| EApp { f = EOp (Unop (Log (PosRecordIfTrueBool, _))), m; args = [arg1] }
| EApp { f = EOp { op = Log (PosRecordIfTrueBool, _); _ }, m; args = [arg1] }
when !Cli.trace_flag ->
let pos = Expr.mark_pos m in
Format.fprintf fmt
@ -382,15 +327,12 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) format_with_parens arg1
| EApp { f = EOp (Unop (Log (EndCall, info))), _; args = [arg1] }
| EApp { f = EOp { op = Log (EndCall, info); _ }, _; args = [arg1] }
when !Cli.trace_flag ->
Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info
format_with_parens arg1
| EApp { f = EOp (Unop (Log _)), _; args = [arg1] } ->
| EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } ->
Format.fprintf fmt "%a" format_with_parens arg1
| EApp { f = EOp (Unop op), _; args = [arg1] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_unop (op, Pos.no_pos)
format_with_parens arg1
| EApp { f = EVar x, pos; args }
when Var.compare x (Var.translate Ast.handle_default) = 0
|| Var.compare x (Var.translate Ast.handle_default_opt) = 0 ->
@ -419,9 +361,7 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) :
Format.fprintf fmt
"@[<hov 2> if@ @[<hov 2>%a@]@ then@ @[<hov 2>%a@]@ else@ @[<hov 2>%a@]@]"
format_with_parens cond format_with_parens etrue format_with_parens efalse
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
| EOp { op; _ } -> Format.pp_print_string fmt (Operator.name op)
| EAssert e' ->
Format.fprintf fmt
"@[<hov 2>if@ %a@ then@ ()@ else@ raise (AssertionFailed @[<hov \

View File

@ -28,15 +28,15 @@ let handle_default_opt = TopLevelName.fresh ("handle_default_opt", Pos.no_pos)
type expr = naked_expr Marked.pos
and naked_expr =
| EVar of LocalName.t
| EFunc of TopLevelName.t
| EStruct of expr list * StructName.t
| EStructFieldAccess of expr * StructField.t * StructName.t
| EInj of expr * EnumConstructor.t * EnumName.t
| EArray of expr list
| ELit of L.lit
| EApp of expr * expr list
| EOp of lcalc operator
| EVar : LocalName.t -> naked_expr
| EFunc : TopLevelName.t -> naked_expr
| EStruct : expr list * StructName.t -> naked_expr
| EStructFieldAccess : expr * StructField.t * StructName.t -> naked_expr
| EInj : expr * EnumConstructor.t * EnumName.t -> naked_expr
| EArray : expr list -> naked_expr
| ELit : L.lit -> naked_expr
| EApp : expr * expr list -> naked_expr
| EOp : (lcalc, _) operator -> naked_expr
type stmt =
| SInnerFuncDef of LocalName.t Marked.pos * func

View File

@ -86,7 +86,7 @@ let rec translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : A.block * A.expr =
in
let new_args = List.rev new_args in
args_stmts, (A.EArray new_args, Expr.pos expr)
| EOp op -> [], (A.EOp op, Expr.pos expr)
| EOp { op; _ } -> [], (A.EOp op, Expr.pos expr)
| ELit l -> [], (A.ELit l, Expr.pos expr)
| _ ->
let tmp_var =

View File

@ -64,25 +64,24 @@ let rec format_expr
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.enum_constructor cons
format_expr e
| ELit l -> Print.lit fmt l
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.binop op format_with_parens
arg1 format_with_parens arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
| EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" Print.operator op
format_with_parens arg1 format_with_parens arg2
| EApp ((EOp op, _), [arg1; arg2]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" format_with_parens arg1
Print.binop op format_with_parens arg2
| EApp ((EOp (Unop (Log _)), _), [arg1]) when not debug ->
Print.operator op format_with_parens arg2
| EApp ((EOp (Log _), _), [arg1]) when not debug ->
Format.fprintf fmt "%a" format_with_parens arg1
| EApp ((EOp (Unop op), _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.unop op format_with_parens arg1
| EApp ((EOp op, _), [arg1]) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" Print.operator op format_with_parens
arg1
| EApp (f, args) ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" format_expr f
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ")
format_with_parens)
args
| EOp (Ternop op) -> Format.fprintf fmt "%a" Print.ternop op
| EOp (Binop op) -> Format.fprintf fmt "%a" Print.binop op
| EOp (Unop op) -> Format.fprintf fmt "%a" Print.unop op
| EOp op -> Format.fprintf fmt "%a" Print.operator op
let rec format_statement
(decl_ctx : decl_ctx)

View File

@ -24,11 +24,11 @@ module L = Lcalc.Ast
let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
match Marked.unmark l with
| LBool true -> Format.fprintf fmt "True"
| LBool false -> Format.fprintf fmt "False"
| LBool true -> Format.pp_print_string fmt "True"
| LBool false -> Format.pp_print_string fmt "False"
| LInt i ->
Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i)
| LUnit -> Format.fprintf fmt "Unit()"
| LUnit -> Format.pp_print_string fmt "Unit()"
| LRat i -> Format.fprintf fmt "decimal_of_string(\"%a\")" Print.lit (LRat i)
| LMoney e ->
Format.fprintf fmt "money_of_cents_string(\"%s\")"
@ -44,31 +44,59 @@ let format_lit (fmt : Format.formatter) (l : L.lit Marked.pos) : unit =
let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
match entry with
| VarDef _ -> Format.fprintf fmt ":="
| BeginCall -> Format.fprintf fmt ""
| VarDef _ -> Format.pp_print_string fmt ":="
| BeginCall -> Format.pp_print_string fmt ""
| EndCall -> Format.fprintf fmt "%s" ""
| PosRecordIfTrueBool -> Format.fprintf fmt ""
| PosRecordIfTrueBool -> Format.pp_print_string fmt ""
let format_binop (fmt : Format.formatter) (op : lcalc binop Marked.pos) : unit =
let format_op
(type k)
(fmt : Format.formatter)
(op : (lcalc, k) operator Marked.pos) : unit =
match Marked.unmark op with
| Add _ | Concat -> Format.fprintf fmt "+"
| Sub _ -> Format.fprintf fmt "-"
| Mult _ -> Format.fprintf fmt "*"
| Div KInt -> Format.fprintf fmt "//"
| Div _ -> Format.fprintf fmt "/"
| And -> Format.fprintf fmt "and"
| Or -> Format.fprintf fmt "or"
| Eq -> Format.fprintf fmt "=="
| Neq | Xor -> Format.fprintf fmt "!="
| Lt _ -> Format.fprintf fmt "<"
| Lte _ -> Format.fprintf fmt "<="
| Gt _ -> Format.fprintf fmt ">"
| Gte _ -> Format.fprintf fmt ">="
| Map -> Format.fprintf fmt "list_map"
| Filter -> Format.fprintf fmt "list_filter"
let format_ternop (fmt : Format.formatter) (op : ternop Marked.pos) : unit =
match Marked.unmark op with Fold -> Format.fprintf fmt "list_fold_left"
| Log (entry, infos) -> assert false
| Minus_int | Minus_rat | Minus_mon | Minus_dur ->
Format.pp_print_string fmt "-"
(* Todo: use the names from [Operator.name] *)
| Not -> Format.pp_print_string fmt "not"
| Length -> Format.pp_print_string fmt "list_length"
| IntToRat -> Format.pp_print_string fmt "decimal_of_integer"
| MoneyToRat -> Format.pp_print_string fmt "decimal_of_money"
| RatToMoney -> Format.pp_print_string fmt "money_of_decimal"
| GetDay -> Format.pp_print_string fmt "day_of_month_of_date"
| GetMonth -> Format.pp_print_string fmt "month_number_of_date"
| GetYear -> Format.pp_print_string fmt "year_of_date"
| FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
| LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
| RoundMoney -> Format.pp_print_string fmt "money_round"
| RoundDecimal -> Format.pp_print_string fmt "decimal_round"
| Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur | Concat
->
Format.pp_print_string fmt "+"
| Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
| Sub_dur_dur ->
Format.pp_print_string fmt "-"
| Mult_int_int | Mult_rat_rat | Mult_mon_rat | Mult_dur_int ->
Format.pp_print_string fmt "*"
| Div_int_int -> Format.pp_print_string fmt "//"
| Div_rat_rat | Div_mon_mon | Div_mon_rat -> Format.pp_print_string fmt "/"
| And -> Format.pp_print_string fmt "and"
| Or -> Format.pp_print_string fmt "or"
| Eq -> Format.pp_print_string fmt "=="
| Xor -> Format.pp_print_string fmt "!="
| Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat | Lt_dur_dur ->
Format.pp_print_string fmt "<"
| Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur ->
Format.pp_print_string fmt "<="
| Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur ->
Format.pp_print_string fmt ">"
| Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur ->
Format.pp_print_string fmt ">="
| Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur ->
Format.pp_print_string fmt "=="
| Map -> Format.pp_print_string fmt "list_map"
| Filter -> Format.pp_print_string fmt "list_filter"
| Fold -> Format.pp_print_string fmt "list_fold_left"
let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list)
: unit =
@ -89,23 +117,6 @@ let format_string_list (fmt : Format.formatter) (uids : string list) : unit =
(Re.replace sanitize_quotes ~f:(fun _ -> "\\\"") info)))
uids
let format_unop (fmt : Format.formatter) (op : lcalc unop Marked.pos) : unit =
match Marked.unmark op with
| Minus _ -> Format.fprintf fmt "-"
| Not -> Format.fprintf fmt "not"
| Log (entry, infos) -> assert false (* should not happen *)
| Length -> Format.fprintf fmt "%s" "list_length"
| IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer"
| MoneyToRat -> Format.fprintf fmt "%s" "decimal_of_money"
| RatToMoney -> Format.fprintf fmt "%s" "money_of_decimal"
| GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date"
| GetMonth -> Format.fprintf fmt "%s" "month_number_of_date"
| GetYear -> Format.fprintf fmt "%s" "year_of_date"
| FirstDayOfMonth -> Format.fprintf fmt "%s" "first_day_of_month"
| LastDayOfMonth -> Format.fprintf fmt "%s" "last_day_of_month"
| RoundMoney -> Format.fprintf fmt "%s" "money_round"
| RoundDecimal -> Format.fprintf fmt "%s" "decimal_round"
let avoid_keywords (s : string) : string =
if
match s with
@ -298,21 +309,20 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e))
es
| ELit l -> Format.fprintf fmt "%a" format_lit (Marked.same_mark_as l e)
| EApp ((EOp (Binop ((Map | Filter) as op)), _), [arg1; arg2]) ->
Format.fprintf fmt "%a(%a,@ %a)" format_binop (op, Pos.no_pos)
| EApp ((EOp ((Map | Filter) as op), _), [arg1; arg2]) ->
Format.fprintf fmt "%a(%a,@ %a)" format_op (op, Pos.no_pos)
(format_expression ctx) arg1 (format_expression ctx) arg2
| EApp ((EOp (Binop op), _), [arg1; arg2]) ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_binop
| EApp ((EOp op, _), [arg1; arg2]) ->
Format.fprintf fmt "(%a %a@ %a)" (format_expression ctx) arg1 format_op
(op, Pos.no_pos) (format_expression ctx) arg2
| EApp ((EApp ((EOp (Unop (Log (BeginCall, info))), _), [f]), _), [arg])
| EApp ((EApp ((EOp (Log (BeginCall, info)), _), [f]), _), [arg])
when !Cli.trace_flag ->
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info
(format_expression ctx) f (format_expression ctx) arg
| EApp ((EOp (Unop (Log (VarDef tau, info))), _), [arg1]) when !Cli.trace_flag
->
| EApp ((EOp (Log (VarDef tau, info)), _), [arg1]) when !Cli.trace_flag ->
Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1
| EApp ((EOp (Unop (Log (PosRecordIfTrueBool, _))), pos), [arg1])
| EApp ((EOp (Log (PosRecordIfTrueBool, _)), pos), [arg1])
when !Cli.trace_flag ->
Format.fprintf fmt
"log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \
@ -320,16 +330,21 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
(Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos)
(Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list
(Pos.get_law_info pos) (format_expression ctx) arg1
| EApp ((EOp (Unop (Log (EndCall, info))), _), [arg1]) when !Cli.trace_flag ->
| EApp ((EOp (Log (EndCall, info)), _), [arg1]) when !Cli.trace_flag ->
Format.fprintf fmt "log_end_call(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1
| EApp ((EOp (Unop (Log _)), _), [arg1]) ->
| EApp ((EOp (Log _), _), [arg1]) ->
Format.fprintf fmt "%a" (format_expression ctx) arg1
| EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [arg1]) ->
Format.fprintf fmt "%a %a" format_unop (op, Pos.no_pos)
| EApp ((EOp Not, _), [arg1]) ->
Format.fprintf fmt "%a %a" format_op (Not, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EOp (Unop op), _), [arg1]) ->
Format.fprintf fmt "%a(%a)" format_unop (op, Pos.no_pos)
| EApp
((EOp ((Minus_int | Minus_rat | Minus_mon | Minus_dur) as op), _), [arg1])
->
Format.fprintf fmt "%a %a" format_op (op, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EOp op, _), [arg1]) ->
Format.fprintf fmt "%a(%a)" format_op (op, Pos.no_pos)
(format_expression ctx) arg1
| EApp ((EFunc x, pos), args)
when Ast.TopLevelName.compare x Ast.handle_default = 0
@ -350,9 +365,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ")
(format_expression ctx))
args
| EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos)
| EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos)
| EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos)
| EOp op -> Format.fprintf fmt "%a" format_op (op, Pos.no_pos)
let rec format_statement
(ctx : decl_ctx)

View File

@ -36,7 +36,7 @@ let tag_with_log_entry
(l : log_entry)
(markings : Uid.MarkedString.info list) : untyped Ast.expr boxed =
Expr.eapp
(Expr.eop (Unop (Log (l, markings))) (Marked.get_mark e))
(Expr.eop (Log (l, markings)) [TAny, Expr.pos e] (Marked.get_mark e))
[e] (Marked.get_mark e)
let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) :
@ -128,9 +128,23 @@ let rec translate_expr (ctx : ctx) (e : Desugared.Ast.expr) :
ctx (Array.to_list vars) (Array.to_list new_vars)
in
Expr.eabs (Expr.bind new_vars (translate_expr ctx body)) tys m
| EApp { f = EOp { op; tys }, m1; args } ->
let args = List.map (translate_expr ctx) args in
Operator.kind_dispatch op
~monomorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
~polymorphic:(fun op -> Expr.eapp (Expr.eop op tys m1) args m)
~overloaded:(fun op ->
match
Operator.resolve_overload ctx.decl_ctx
(Marked.mark (Expr.pos e) op)
tys
with
| op, `Straight -> Expr.eapp (Expr.eop op tys m1) args m
| op, `Reversed ->
Expr.eapp (Expr.eop op (List.rev tys) m1) (List.rev args) m)
| EOp _ -> assert false (* Only allowed within [EApp] *)
| EApp { f; args } ->
Expr.eapp (translate_expr ctx f) (List.map (translate_expr ctx) args) m
| EOp op -> Expr.eop (Expr.translate_op op) m
| EDefault { excepts; just; cons } ->
Expr.edefault
(List.map (translate_expr ctx) excepts)

View File

@ -85,7 +85,7 @@ let scope ?(debug = false) ctx fmt (name, decl) =
.io_input
with
| Reentrant ->
Format.fprintf fmt "%a@ %a" Print.operator
Format.fprintf fmt "%a@ %a" Print.op_style
"reentrant or by default" (Print.expr ~debug ctx) e
| _ -> Format.fprintf fmt "%a" (Print.expr ~debug ctx) e))
e

View File

@ -82,34 +82,6 @@ and naked_typ =
type date = Runtime.date
type duration = Runtime.duration
type 'a op_kind =
(* | Kpoly: desugared op_kind -- Coming soon ! *)
| KInt : 'a any op_kind
| KRat : 'a any op_kind
| KMoney : 'a any op_kind
| KDate : 'a any op_kind
| KDuration : 'a any op_kind (** All ops don't have a KDate and KDuration. *)
type ternop = Fold
type 'a binop =
| And
| Or
| Xor
| Add of 'a op_kind
| Sub of 'a op_kind
| Mult of 'a op_kind
| Div of 'a op_kind
| Lt of 'a op_kind
| Lte of 'a op_kind
| Gt of 'a op_kind
| Gte of 'a op_kind
| Eq
| Neq
| Map
| Concat
| Filter
type log_entry =
| VarDef of naked_typ
(** During code generation, we need to know the type of the variable being
@ -118,23 +90,131 @@ type log_entry =
| EndCall
| PosRecordIfTrueBool
type 'a unop =
| Not
| Minus of 'a op_kind
| Log of log_entry * Uid.MarkedString.info list
| Length
| IntToRat
| MoneyToRat
| RatToMoney
| GetDay
| GetMonth
| GetYear
| FirstDayOfMonth
| LastDayOfMonth
| RoundMoney
| RoundDecimal
module Op = struct
(** Classification of operators on how they should be typed *)
type 'a operator = Ternop of ternop | Binop of 'a binop | Unop of 'a unop
type monomorphic =
| Monomorphic (** Operands and return types of the operator are fixed *)
type polymorphic =
| Polymorphic
(** The operator is truly polymorphic: it's the same runtime function
that may work on multiple types. We require that resolving the
argument types from right to left trivially resolves all type
variables declared in the operator type. *)
type overloaded =
| Overloaded
(** The operator is ambiguous and requires the types of its arguments to
be known before it can be typed, using a pre-defined table *)
type resolved =
| Resolved (** Explicit monomorphic versions of the overloaded operators *)
(** Classification of operators. This could be inlined in the definition of
[t] but is more concise this way *)
type (_, _) kind =
| Monomorphic : ('a any, monomorphic) kind
| Polymorphic : ('a any, polymorphic) kind
| Overloaded : ([< desugared ], overloaded) kind
| Resolved : ([< scopelang | dcalc | lcalc ], resolved) kind
type (_, _) t =
(* unary *)
(* * monomorphic *)
| Not : ('a any, monomorphic) t
(* Todo: [AToB] operators could actually be overloaded [ToB] operators*)
| IntToRat : ('a any, monomorphic) t
| MoneyToRat : ('a any, monomorphic) t
| RatToMoney : ('a any, monomorphic) t
| GetDay : ('a any, monomorphic) t
| GetMonth : ('a any, monomorphic) t
| GetYear : ('a any, monomorphic) t
| FirstDayOfMonth : ('a any, monomorphic) t
| LastDayOfMonth : ('a any, monomorphic) t
| RoundMoney : ('a any, monomorphic) t
| RoundDecimal : ('a any, monomorphic) t
(* * polymorphic *)
| Length : ('a any, polymorphic) t
| Log : log_entry * Uid.MarkedString.info list -> ('a any, polymorphic) t
(* * overloaded *)
| Minus : (desugared, overloaded) t
| Minus_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Minus_dur : ([< scopelang | dcalc | lcalc ], resolved) t
(* binary *)
(* * monomorphic *)
| And : ('a any, monomorphic) t
| Or : ('a any, monomorphic) t
| Xor : ('a any, monomorphic) t
(* * polymorphic *)
| Eq : ('a any, polymorphic) t
| Map : ('a any, polymorphic) t
| Concat : ('a any, polymorphic) t
| Filter : ('a any, polymorphic) t
(* * overloaded *)
| Add : (desugared, overloaded) t
| Add_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Add_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub : (desugared, overloaded) t
| Sub_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dat_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Sub_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult : (desugared, overloaded) t
| Mult_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Mult_dur_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Div : (desugared, overloaded) t
| Div_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_mon_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Div_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt : (desugared, overloaded) t
| Lt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte : (desugared, overloaded) t
| Lte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Lte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt : (desugared, overloaded) t
| Gt_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gt_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte : (desugared, overloaded) t
| Gte_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
| Gte_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
(* Todo: Eq is not an overload at the moment, but it should be one. The
trick is that it needs generation of specific code for arrays, every
struct and enum: operators [Eq_structs of StructName.t], etc. *)
| Eq_int_int : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_rat_rat : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_mon_mon : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_dur_dur : ([< scopelang | dcalc | lcalc ], resolved) t
| Eq_dat_dat : ([< scopelang | dcalc | lcalc ], resolved) t
(* ternary *)
(* * polymorphic *)
| Fold : ('a any, polymorphic) t
end
type ('a, 'k) operator = ('a any, 'k) Op.t
type except = ConflictError | EmptyError | NoValueProvided | Crash
(** {2 Generic expressions} *)
@ -175,7 +255,9 @@ type ('a, 't) gexpr = (('a, 't) naked_gexpr, 't) Marked.t
- To write a function that handles cases from different ASTs, explicit the
type variables: [fun (type a) (x: a naked_gexpr) -> ...]
- For recursive functions, you may need to additionally explicit the
generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...] *)
generalisation of the variable: [let rec f: type a . a naked_gexpr -> ...]
- Always think of using the pre-defined map/fold functions in [Expr] rather
than completely defining your recursion manually. *)
and ('a, 't) naked_gexpr =
(* Constructors common to all ASTs *)
@ -185,7 +267,7 @@ and ('a, 't) naked_gexpr =
args : ('a, 't) gexpr list;
}
-> ('a any, 't) naked_gexpr
| EOp : 'a operator -> ('a any, 't) naked_gexpr
| EOp : { op : ('a, _) operator; tys : typ list } -> ('a any, 't) naked_gexpr
| EArray : ('a, 't) gexpr list -> ('a any, 't) naked_gexpr
| EVar : ('a, 't) naked_gexpr Bindlib.var -> ('a any, 't) naked_gexpr
| EAbs : {

View File

@ -90,7 +90,7 @@ let eabs binder tys mark =
let eapp f args = Box.app1n f args @@ fun f args -> EApp { f; args }
let eassert e1 = Box.app1 e1 @@ fun e1 -> EAssert e1
let eop op = Box.app0 @@ EOp op
let eop op tys = Box.app0 @@ EOp { op; tys }
let edefault excepts just cons =
Box.app2n just cons excepts
@ -212,7 +212,7 @@ let map
match Marked.unmark e with
| ELit l -> elit l m
| EApp { f = e1; args } -> eapp (f e1) (List.map f args) m
| EOp op -> eop op m
| EOp { op; tys } -> eop op tys m
| EArray args -> earray (List.map f args) m
| EVar v -> evar (Var.translate v) m
| EAbs { binder; tys } ->
@ -302,7 +302,7 @@ let map_gather
let acc1, f = f e1 in
let acc2, args = lfoldmap args in
join acc1 acc2, eapp f args m
| EOp op -> acc, eop op m
| EOp { op; tys } -> acc, eop op tys m
| EArray args ->
let acc, args = lfoldmap args in
acc, earray args m
@ -396,99 +396,36 @@ let is_value (type a) (e : (a, _) gexpr) =
| ELit _ | EAbs _ | EOp _ | ERaise _ -> true
| _ -> false
let equal_tlit l1 l2 = l1 = l2
let compare_tlit l1 l2 = Stdlib.compare l1 l2
let rec equal_typ ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> equal_typ_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> equal_typ t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> equal_typ t1 t2 && equal_typ t1' t2'
| TArray t1, TArray t2 -> equal_typ t1 t2
| TAny, TAny -> true
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
| TArray _ | TAny ),
_ ) ->
false
and equal_typ_list tys1 tys2 =
try List.for_all2 equal_typ tys1 tys2 with Invalid_argument _ -> false
(* Similar to [equal_typ], but allows TAny holes *)
let rec unifiable ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TAny, _ | _, TAny -> true
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> unifiable t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
| TArray t1, TArray t2 -> unifiable t1 t2
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
_ ) ->
false
and unifiable_list tys1 tys2 =
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
let rec compare_typ ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> compare_tlit l1 l2
| TTuple tys1, TTuple tys2 -> List.compare compare_typ tys1 tys2
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
| TOption t1, TOption t2 -> compare_typ t1 t2
| TArrow (a1, b1), TArrow (a2, b2) -> (
match compare_typ a1 a2 with 0 -> compare_typ b1 b2 | n -> n)
| TArray t1, TArray t2 -> compare_typ t1 t2
| TAny, TAny -> 0
| TLit _, _ -> -1
| _, TLit _ -> 1
| TTuple _, _ -> -1
| _, TTuple _ -> 1
| TStruct _, _ -> -1
| _, TStruct _ -> 1
| TEnum _, _ -> -1
| _, TEnum _ -> 1
| TOption _, _ -> -1
| _, TOption _ -> 1
| TArrow _, _ -> -1
| _, TArrow _ -> 1
| TArray _, _ -> -1
| _, TArray _ -> 1
let equal_lit (type a) (l1 : a glit) (l2 : a glit) =
let open Runtime.Oper in
match l1, l2 with
| LBool b1, LBool b2 -> Bool.equal b1 b2
| LBool b1, LBool b2 -> not (o_xor b1 b2)
| LEmptyError, LEmptyError -> true
| LInt n1, LInt n2 -> Runtime.( =! ) n1 n2
| LRat r1, LRat r2 -> Runtime.( =& ) r1 r2
| LMoney m1, LMoney m2 -> Runtime.( =$ ) m1 m2
| LInt n1, LInt n2 -> o_eq_int_int n1 n2
| LRat r1, LRat r2 -> o_eq_rat_rat r1 r2
| LMoney m1, LMoney m2 -> o_eq_mon_mon m1 m2
| LUnit, LUnit -> true
| LDate d1, LDate d2 -> Runtime.( =@ ) d1 d2
| LDuration d1, LDuration d2 -> Runtime.( =^ ) d1 d2
| LDate d1, LDate d2 -> o_eq_dat_dat d1 d2
| LDuration d1, LDuration d2 -> o_eq_dur_dur d1 d2
| ( ( LBool _ | LEmptyError | LInt _ | LRat _ | LMoney _ | LUnit | LDate _
| LDuration _ ),
_ ) ->
false
let compare_lit (type a) (l1 : a glit) (l2 : a glit) =
let open Runtime.Oper in
match l1, l2 with
| LBool b1, LBool b2 -> Bool.compare b1 b2
| LEmptyError, LEmptyError -> 0
| LInt n1, LInt n2 ->
if Runtime.( <! ) n1 n2 then -1 else if Runtime.( =! ) n1 n2 then 0 else 1
if o_lt_int_int n1 n2 then -1 else if o_eq_int_int n1 n2 then 0 else 1
| LRat r1, LRat r2 ->
if Runtime.( <& ) r1 r2 then -1 else if Runtime.( =& ) r1 r2 then 0 else 1
if o_lt_rat_rat r1 r2 then -1 else if o_eq_rat_rat r1 r2 then 0 else 1
| LMoney m1, LMoney m2 ->
if Runtime.( <$ ) m1 m2 then -1 else if Runtime.( =$ ) m1 m2 then 0 else 1
if o_lt_mon_mon m1 m2 then -1 else if o_eq_mon_mon m1 m2 then 0 else 1
| LUnit, LUnit -> 0
| LDate d1, LDate d2 ->
if Runtime.( <@ ) d1 d2 then -1 else if Runtime.( =@ ) d1 d2 then 0 else 1
if o_lt_dat_dat d1 d2 then -1 else if o_eq_dat_dat d1 d2 then 0 else 1
| LDuration d1, LDuration d2 -> (
(* Duration comparison in the runtime may fail, so rely on a basic
lexicographic comparison instead *)
@ -540,119 +477,6 @@ let compare_location
| _, SubScopeVar _ -> .
let equal_location a b = compare_location a b = 0
let equal_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> equal_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
| x, y -> x = y
let compare_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> compare_typ (t1, Pos.no_pos) (t2, Pos.no_pos)
| BeginCall, BeginCall
| EndCall, EndCall
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
0
| VarDef _, _ -> -1
| _, VarDef _ -> 1
| BeginCall, _ -> -1
| _, BeginCall -> 1
| EndCall, _ -> -1
| _, EndCall -> 1
| PosRecordIfTrueBool, _ -> .
| _, PosRecordIfTrueBool -> .
(* let equal_op_kind = Stdlib.(=) *)
let compare_op_kind = Stdlib.compare
let equal_unops op1 op2 =
match op1, op2 with
(* Log entries contain a typ which contain position information, we thus need
to descend into them *)
| Log (l1, info1), Log (l2, info2) ->
equal_log_entries l1 l2 && List.equal Uid.MarkedString.equal info1 info2
| Log _, _ | _, Log _ -> false
(* All the other cases can be discharged through equality *)
| ( ( Not | Minus _ | Length | IntToRat | MoneyToRat | RatToMoney | GetDay
| GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | RoundMoney
| RoundDecimal ),
_ ) ->
op1 = op2
let compare_unops op1 op2 =
match op1, op2 with
| Not, Not -> 0
| Minus k1, Minus k2 -> compare_op_kind k1 k2
| Log (l1, info1), Log (l2, info2) -> (
match compare_log_entries l1 l2 with
| 0 -> List.compare Uid.MarkedString.compare info1 info2
| n -> n)
| Length, Length
| IntToRat, IntToRat
| MoneyToRat, MoneyToRat
| RatToMoney, RatToMoney
| GetDay, GetDay
| GetMonth, GetMonth
| GetYear, GetYear
| FirstDayOfMonth, FirstDayOfMonth
| LastDayOfMonth, LastDayOfMonth
| RoundMoney, RoundMoney
| RoundDecimal, RoundDecimal ->
0
| Not, _ -> -1
| _, Not -> 1
| Minus _, _ -> -1
| _, Minus _ -> 1
| Log _, _ -> -1
| _, Log _ -> 1
| Length, _ -> -1
| _, Length -> 1
| IntToRat, _ -> -1
| _, IntToRat -> 1
| MoneyToRat, _ -> -1
| _, MoneyToRat -> 1
| RatToMoney, _ -> -1
| _, RatToMoney -> 1
| GetDay, _ -> -1
| _, GetDay -> 1
| GetMonth, _ -> -1
| _, GetMonth -> 1
| GetYear, _ -> -1
| _, GetYear -> 1
| FirstDayOfMonth, _ -> -1
| _, FirstDayOfMonth -> 1
| LastDayOfMonth, _ -> -1
| _, LastDayOfMonth -> 1
| RoundMoney, _ -> -1
| _, RoundMoney -> 1
| RoundDecimal, _ -> .
| _, RoundDecimal -> .
let equal_binop = Stdlib.( = )
let compare_binop = Stdlib.compare
let equal_ternop = Stdlib.( = )
let compare_ternop = Stdlib.compare
let equal_ops op1 op2 =
match op1, op2 with
| Ternop op1, Ternop op2 -> equal_ternop op1 op2
| Binop op1, Binop op2 -> equal_binop op1 op2
| Unop op1, Unop op2 -> equal_unops op1 op2
| _, _ -> false
let compare_op op1 op2 =
match op1, op2 with
| Ternop op1, Ternop op2 -> compare_ternop op1 op2
| Binop op1, Binop op2 -> compare_binop op1 op2
| Unop op1, Unop op2 -> compare_unops op1 op2
| Ternop _, _ -> -1
| _, Ternop _ -> 1
| Binop _, _ -> -1
| _, Binop _ -> 1
| Unop _, _ -> .
| _, Unop _ -> .
let equal_except ex1 ex2 = ex1 = ex2
let compare_except ex1 ex2 = Stdlib.compare ex1 ex2
@ -673,7 +497,7 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool =
| EArray es1, EArray es2 -> equal_list es1 es2
| ELit l1, ELit l2 -> l1 = l2
| EAbs { binder = b1; tys = tys1 }, EAbs { binder = b2; tys = tys2 } ->
equal_typ_list tys1 tys2
Type.equal_list tys1 tys2
&&
let vars1, body1 = Bindlib.unmbind b1 in
let body2 = Bindlib.msubst b2 (Array.map (fun x -> EVar x) vars1) in
@ -681,7 +505,8 @@ and equal : type a. (a, 't) gexpr -> (a, 't) gexpr -> bool =
| EApp { f = e1; args = args1 }, EApp { f = e2; args = args2 } ->
equal e1 e2 && equal_list args1 args2
| EAssert e1, EAssert e2 -> equal e1 e2
| EOp op1, EOp op2 -> equal_ops op1 op2
| EOp { op = op1; tys = tys1 }, EOp { op = op2; tys = tys2 } ->
Operator.equal op1 op2 && Type.equal_list tys1 tys2
| ( EDefault { excepts = exc1; just = def1; cons = cons1 },
EDefault { excepts = exc2; just = def2; cons = cons2 } ) ->
equal def1 def2 && equal cons1 cons2 && equal_list exc1 exc2
@ -734,15 +559,16 @@ let rec compare : type a. (a, _) gexpr -> (a, _) gexpr -> int =
| EApp {f=f1; args=args1}, EApp {f=f2; args=args2} ->
compare f1 f2 @@< fun () ->
List.compare compare args1 args2
| EOp op1, EOp op2 ->
compare_op op1 op2
| EOp {op=op1; tys=tys1}, EOp {op=op2; tys=tys2} ->
Operator.compare op1 op2 @@< fun () ->
List.compare Type.compare tys1 tys2
| EArray a1, EArray a2 ->
List.compare compare a1 a2
| EVar v1, EVar v2 ->
Bindlib.compare_vars v1 v2
| EAbs {binder=binder1; tys=typs1},
EAbs {binder=binder2; tys=typs2} ->
List.compare compare_typ typs1 typs2 @@< fun () ->
List.compare Type.compare typs1 typs2 @@< fun () ->
let _, e1, e2 = Bindlib.unmbind2 binder1 binder2 in
compare e1 e2
| EIfThenElse {cond=i1; etrue=t1; efalse=e1},
@ -835,7 +661,7 @@ let rec free_vars : type a. (a, 't) gexpr -> (a, 't) gexpr Var.Set.t = function
let remove_logging_calls e =
let rec f e =
match Marked.unmark e with
| EApp { f = EOp (Unop (Log _)), _; args = [arg] } -> map ~f arg
| EApp { f = EOp { op = Log _; _ }, _; args = [arg] } -> map ~f arg
| _ -> map ~f e
in
f e
@ -903,7 +729,7 @@ let make_app e u pos =
(fun tf tx ->
match Marked.unmark tf with
| TArrow (tx', tr) ->
assert (unifiable tx.ty tx');
assert (Type.unifiable tx.ty tx');
(* wrong arg type *)
tr
| TAny -> tf
@ -930,7 +756,7 @@ let make_multiple_let_in xs taus e1s e2 mpos =
let make_default_unboxed excepts just cons =
let rec bool_value = function
| ELit (LBool b), _ -> Some b
| EApp { f = EOp (Unop (Log (l, _))), _; args = [e]; _ }, _
| EApp { f = EOp { op = Log (l, _); _ }, _; args = [e]; _ }, _
when l <> PosRecordIfTrueBool
(* we don't remove the log calls corresponding to source code
definitions !*) ->
@ -959,33 +785,3 @@ let make_tuple el m0 =
(List.map (fun e -> Marked.get_mark e) el)
in
etuple el m
let translate_op_kind : type a. a op_kind -> 'b op_kind = function
| KInt -> KInt
| KRat -> KRat
| KMoney -> KMoney
| KDate -> KDate
| KDuration -> KDuration
let translate_op : type a. a operator -> 'b operator = function
| Ternop o -> Ternop o
| Binop o ->
Binop
(match o with
| Add k -> Add (translate_op_kind k)
| Sub k -> Sub (translate_op_kind k)
| Mult k -> Mult (translate_op_kind k)
| Div k -> Div (translate_op_kind k)
| Lt k -> Lt (translate_op_kind k)
| Lte k -> Lte (translate_op_kind k)
| Gt k -> Gt (translate_op_kind k)
| Gte k -> Gte (translate_op_kind k)
| (And | Or | Xor | Eq | Neq | Map | Concat | Filter) as o -> o)
| Unop o ->
Unop
(match o with
| Minus k -> Minus (translate_op_kind k)
| ( Not | Log _ | Length | IntToRat | MoneyToRat | RatToMoney | GetDay
| GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth | RoundMoney
| RoundDecimal ) as o ->
o)

View File

@ -66,7 +66,7 @@ val eapp :
val eassert :
(([< dcalc | lcalc ] as 'a), 't) boxed_gexpr -> 't -> ('a, 't) boxed_gexpr
val eop : 'a any operator -> 't -> ('a, 't) boxed_gexpr
val eop : ('a any, 'k) operator -> typ list -> 't -> ('a, 't) boxed_gexpr
val edefault :
(([< desugared | scopelang | dcalc ] as 'a), 't) boxed_gexpr list ->
@ -310,11 +310,6 @@ val make_tuple :
(** {2 Transformations} *)
val translate_op :
[< desugared | scopelang | dcalc | lcalc ] operator -> 'b any operator
(** Operators are actually all the same after initial desambiguation, so this
function allows converting their types ; otherwise, this is the identity *)
val remove_logging_calls : ('a any, 't) gexpr -> ('a, 't) boxed_gexpr
(** Removes all calls to [Log] unary operators in the AST, replacing them by
their argument. *)
@ -340,8 +335,6 @@ val compare : ('a, 't) gexpr -> ('a, 't) gexpr -> int
(** Standard comparison function, suitable for e.g. [Set.Make]. Ignores position
information *)
val equal_typ : typ -> typ -> bool
val compare_typ : typ -> typ -> int
val is_value : ('a any, 't) gexpr -> bool
val free_vars : ('a any, 't) gexpr -> ('a, 't) gexpr Var.Set.t

View File

@ -0,0 +1,546 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Definitions
include Definitions.Op
let name : type a k. (a, k) t -> string = function
| Not -> "o_not"
| Length -> "o_length"
| IntToRat -> "o_intToRat"
| MoneyToRat -> "o_moneyToRat"
| RatToMoney -> "o_ratToMoney"
| GetDay -> "o_getDay"
| GetMonth -> "o_getMonth"
| GetYear -> "o_getYear"
| FirstDayOfMonth -> "o_firstDayOfMonth"
| LastDayOfMonth -> "o_lastDayOfMonth"
| RoundMoney -> "o_roundMoney"
| RoundDecimal -> "o_roundDecimal"
| Log _ -> "o_log"
| Minus -> "o_minus"
| Minus_int -> "o_minus_int"
| Minus_rat -> "o_minus_rat"
| Minus_mon -> "o_minus_mon"
| Minus_dur -> "o_minus_dur"
| And -> "o_and"
| Or -> "o_or"
| Xor -> "o_xor"
| Eq -> "o_eq"
| Map -> "o_map"
| Concat -> "o_concat"
| Filter -> "o_filter"
| Add -> "o_add"
| Add_int_int -> "o_add_int_int"
| Add_rat_rat -> "o_add_rat_rat"
| Add_mon_mon -> "o_add_mon_mon"
| Add_dat_dur -> "o_add_dat_dur"
| Add_dur_dur -> "o_add_dur_dur"
| Sub -> "o_sub"
| Sub_int_int -> "o_sub_int_int"
| Sub_rat_rat -> "o_sub_rat_rat"
| Sub_mon_mon -> "o_sub_mon_mon"
| Sub_dat_dat -> "o_sub_dat_dat"
| Sub_dat_dur -> "o_sub_dat_dur"
| Sub_dur_dur -> "o_sub_dur_dur"
| Mult -> "o_mult"
| Mult_int_int -> "o_mult_int_int"
| Mult_rat_rat -> "o_mult_rat_rat"
| Mult_mon_rat -> "o_mult_mon_rat"
| Mult_dur_int -> "o_mult_dur_int"
| Div -> "o_div"
| Div_int_int -> "o_div_int_int"
| Div_rat_rat -> "o_div_rat_rat"
| Div_mon_mon -> "o_div_mon_mon"
| Div_mon_rat -> "o_div_mon_mon"
| Lt -> "o_lt"
| Lt_int_int -> "o_lt_int_int"
| Lt_rat_rat -> "o_lt_rat_rat"
| Lt_mon_mon -> "o_lt_mon_mon"
| Lt_dur_dur -> "o_lt_dur_dur"
| Lt_dat_dat -> "o_lt_dat_dat"
| Lte -> "o_lte"
| Lte_int_int -> "o_lte_int_int"
| Lte_rat_rat -> "o_lte_rat_rat"
| Lte_mon_mon -> "o_lte_mon_mon"
| Lte_dur_dur -> "o_lte_dur_dur"
| Lte_dat_dat -> "o_lte_dat_dat"
| Gt -> "o_gt"
| Gt_int_int -> "o_gt_int_int"
| Gt_rat_rat -> "o_gt_rat_rat"
| Gt_mon_mon -> "o_gt_mon_mon"
| Gt_dur_dur -> "o_gt_dur_dur"
| Gt_dat_dat -> "o_gt_dat_dat"
| Gte -> "o_gte"
| Gte_int_int -> "o_gte_int_int"
| Gte_rat_rat -> "o_gte_rat_rat"
| Gte_mon_mon -> "o_gte_mon_mon"
| Gte_dur_dur -> "o_gte_dur_dur"
| Gte_dat_dat -> "o_gte_dat_dat"
| Eq_int_int -> "o_eq_int_int"
| Eq_rat_rat -> "o_eq_rat_rat"
| Eq_mon_mon -> "o_eq_mon_mon"
| Eq_dur_dur -> "o_eq_dur_dur"
| Eq_dat_dat -> "o_eq_dat_dat"
| Fold -> "o_fold"
let compare_log_entries l1 l2 =
match l1, l2 with
| VarDef t1, VarDef t2 -> Type.compare (t1, Pos.no_pos) (t2, Pos.no_pos)
| BeginCall, BeginCall
| EndCall, EndCall
| PosRecordIfTrueBool, PosRecordIfTrueBool ->
0
| VarDef _, _ -> -1
| _, VarDef _ -> 1
| BeginCall, _ -> -1
| _, BeginCall -> 1
| EndCall, _ -> -1
| _, EndCall -> 1
| PosRecordIfTrueBool, _ -> .
| _, PosRecordIfTrueBool -> .
let compare (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) =
match[@ocamlformat "disable"] t1, t2 with
| Log (l1, info1), Log (l2, info2) -> (
match compare_log_entries l1 l2 with
| 0 -> List.compare Uid.MarkedString.compare info1 info2
| n -> n)
| Not, Not
| Length, Length
| IntToRat, IntToRat
| MoneyToRat, MoneyToRat
| RatToMoney, RatToMoney
| GetDay, GetDay
| GetMonth, GetMonth
| GetYear, GetYear
| FirstDayOfMonth, FirstDayOfMonth
| LastDayOfMonth, LastDayOfMonth
| RoundMoney, RoundMoney
| RoundDecimal, RoundDecimal
| Minus, Minus
| Minus_int, Minus_int
| Minus_rat, Minus_rat
| Minus_mon, Minus_mon
| Minus_dur, Minus_dur
| And, And
| Or, Or
| Xor, Xor
| Eq, Eq
| Map, Map
| Concat, Concat
| Filter, Filter
| Add, Add
| Add_int_int, Add_int_int
| Add_rat_rat, Add_rat_rat
| Add_mon_mon, Add_mon_mon
| Add_dat_dur, Add_dat_dur
| Add_dur_dur, Add_dur_dur
| Sub, Sub
| Sub_int_int, Sub_int_int
| Sub_rat_rat, Sub_rat_rat
| Sub_mon_mon, Sub_mon_mon
| Sub_dat_dat, Sub_dat_dat
| Sub_dat_dur, Sub_dat_dur
| Sub_dur_dur, Sub_dur_dur
| Mult, Mult
| Mult_int_int, Mult_int_int
| Mult_rat_rat, Mult_rat_rat
| Mult_mon_rat, Mult_mon_rat
| Mult_dur_int, Mult_dur_int
| Div, Div
| Div_int_int, Div_int_int
| Div_rat_rat, Div_rat_rat
| Div_mon_mon, Div_mon_mon
| Div_mon_rat, Div_mon_rat
| Lt, Lt
| Lt_int_int, Lt_int_int
| Lt_rat_rat, Lt_rat_rat
| Lt_mon_mon, Lt_mon_mon
| Lt_dat_dat, Lt_dat_dat
| Lt_dur_dur, Lt_dur_dur
| Lte, Lte
| Lte_int_int, Lte_int_int
| Lte_rat_rat, Lte_rat_rat
| Lte_mon_mon, Lte_mon_mon
| Lte_dat_dat, Lte_dat_dat
| Lte_dur_dur, Lte_dur_dur
| Gt, Gt
| Gt_int_int, Gt_int_int
| Gt_rat_rat, Gt_rat_rat
| Gt_mon_mon, Gt_mon_mon
| Gt_dat_dat, Gt_dat_dat
| Gt_dur_dur, Gt_dur_dur
| Gte, Gte
| Gte_int_int, Gte_int_int
| Gte_rat_rat, Gte_rat_rat
| Gte_mon_mon, Gte_mon_mon
| Gte_dat_dat, Gte_dat_dat
| Gte_dur_dur, Gte_dur_dur
| Eq_int_int, Eq_int_int
| Eq_rat_rat, Eq_rat_rat
| Eq_mon_mon, Eq_mon_mon
| Eq_dat_dat, Eq_dat_dat
| Eq_dur_dur, Eq_dur_dur
| Fold, Fold -> 0
| Not, _ -> -1 | _, Not -> 1
| Length, _ -> -1 | _, Length -> 1
| IntToRat, _ -> -1 | _, IntToRat -> 1
| MoneyToRat, _ -> -1 | _, MoneyToRat -> 1
| RatToMoney, _ -> -1 | _, RatToMoney -> 1
| GetDay, _ -> -1 | _, GetDay -> 1
| GetMonth, _ -> -1 | _, GetMonth -> 1
| GetYear, _ -> -1 | _, GetYear -> 1
| FirstDayOfMonth, _ -> -1 | _, FirstDayOfMonth -> 1
| LastDayOfMonth, _ -> -1 | _, LastDayOfMonth -> 1
| RoundMoney, _ -> -1 | _, RoundMoney -> 1
| RoundDecimal, _ -> -1 | _, RoundDecimal -> 1
| Log _, _ -> -1 | _, Log _ -> 1
| Minus, _ -> -1 | _, Minus -> 1
| Minus_int, _ -> -1 | _, Minus_int -> 1
| Minus_rat, _ -> -1 | _, Minus_rat -> 1
| Minus_mon, _ -> -1 | _, Minus_mon -> 1
| Minus_dur, _ -> -1 | _, Minus_dur -> 1
| And, _ -> -1 | _, And -> 1
| Or, _ -> -1 | _, Or -> 1
| Xor, _ -> -1 | _, Xor -> 1
| Eq, _ -> -1 | _, Eq -> 1
| Map, _ -> -1 | _, Map -> 1
| Concat, _ -> -1 | _, Concat -> 1
| Filter, _ -> -1 | _, Filter -> 1
| Add, _ -> -1 | _, Add -> 1
| Add_int_int, _ -> -1 | _, Add_int_int -> 1
| Add_rat_rat, _ -> -1 | _, Add_rat_rat -> 1
| Add_mon_mon, _ -> -1 | _, Add_mon_mon -> 1
| Add_dat_dur, _ -> -1 | _, Add_dat_dur -> 1
| Add_dur_dur, _ -> -1 | _, Add_dur_dur -> 1
| Sub, _ -> -1 | _, Sub -> 1
| Sub_int_int, _ -> -1 | _, Sub_int_int -> 1
| Sub_rat_rat, _ -> -1 | _, Sub_rat_rat -> 1
| Sub_mon_mon, _ -> -1 | _, Sub_mon_mon -> 1
| Sub_dat_dat, _ -> -1 | _, Sub_dat_dat -> 1
| Sub_dat_dur, _ -> -1 | _, Sub_dat_dur -> 1
| Sub_dur_dur, _ -> -1 | _, Sub_dur_dur -> 1
| Mult, _ -> -1 | _, Mult -> 1
| Mult_int_int, _ -> -1 | _, Mult_int_int -> 1
| Mult_rat_rat, _ -> -1 | _, Mult_rat_rat -> 1
| Mult_mon_rat, _ -> -1 | _, Mult_mon_rat -> 1
| Mult_dur_int, _ -> -1 | _, Mult_dur_int -> 1
| Div, _ -> -1 | _, Div -> 1
| Div_int_int, _ -> -1 | _, Div_int_int -> 1
| Div_rat_rat, _ -> -1 | _, Div_rat_rat -> 1
| Div_mon_mon, _ -> -1 | _, Div_mon_mon -> 1
| Div_mon_rat, _ -> -1 | _, Div_mon_rat -> 1
| Lt, _ -> -1 | _, Lt -> 1
| Lt_int_int, _ -> -1 | _, Lt_int_int -> 1
| Lt_rat_rat, _ -> -1 | _, Lt_rat_rat -> 1
| Lt_mon_mon, _ -> -1 | _, Lt_mon_mon -> 1
| Lt_dat_dat, _ -> -1 | _, Lt_dat_dat -> 1
| Lt_dur_dur, _ -> -1 | _, Lt_dur_dur -> 1
| Lte, _ -> -1 | _, Lte -> 1
| Lte_int_int, _ -> -1 | _, Lte_int_int -> 1
| Lte_rat_rat, _ -> -1 | _, Lte_rat_rat -> 1
| Lte_mon_mon, _ -> -1 | _, Lte_mon_mon -> 1
| Lte_dat_dat, _ -> -1 | _, Lte_dat_dat -> 1
| Lte_dur_dur, _ -> -1 | _, Lte_dur_dur -> 1
| Gt, _ -> -1 | _, Gt -> 1
| Gt_int_int, _ -> -1 | _, Gt_int_int -> 1
| Gt_rat_rat, _ -> -1 | _, Gt_rat_rat -> 1
| Gt_mon_mon, _ -> -1 | _, Gt_mon_mon -> 1
| Gt_dat_dat, _ -> -1 | _, Gt_dat_dat -> 1
| Gt_dur_dur, _ -> -1 | _, Gt_dur_dur -> 1
| Gte, _ -> -1 | _, Gte -> 1
| Gte_int_int, _ -> -1 | _, Gte_int_int -> 1
| Gte_rat_rat, _ -> -1 | _, Gte_rat_rat -> 1
| Gte_mon_mon, _ -> -1 | _, Gte_mon_mon -> 1
| Gte_dat_dat, _ -> -1 | _, Gte_dat_dat -> 1
| Gte_dur_dur, _ -> -1 | _, Gte_dur_dur -> 1
| Eq_int_int, _ -> -1 | _, Eq_int_int -> 1
| Eq_rat_rat, _ -> -1 | _, Eq_rat_rat -> 1
| Eq_mon_mon, _ -> -1 | _, Eq_mon_mon -> 1
| Eq_dat_dat, _ -> -1 | _, Eq_dat_dat -> 1
| Eq_dur_dur, _ -> -1 | _, Eq_dur_dur -> 1
| Fold, _ | _, Fold -> .
let equal (type a k a2 k2) (t1 : (a, k) t) (t2 : (a2, k2) t) = compare t1 t2 = 0
(* Classification of operators *)
let kind_dispatch :
type a b k.
polymorphic:((_, polymorphic) t -> b) ->
monomorphic:((_, monomorphic) t -> b) ->
?overloaded:((_, overloaded) t -> b) ->
?resolved:((_, resolved) t -> b) ->
(a, k) t ->
b =
fun ~polymorphic ~monomorphic ?(overloaded = fun _ -> assert false)
?(resolved = fun _ -> assert false) op ->
match op with
| ( Not | IntToRat | MoneyToRat | RatToMoney | GetDay | GetMonth | GetYear
| FirstDayOfMonth | LastDayOfMonth | RoundMoney | RoundDecimal | And | Or
| Xor ) as op ->
monomorphic op
| (Log _ | Length | Eq | Map | Concat | Filter | Fold) as op -> polymorphic op
| (Minus | Add | Sub | Mult | Div | Lt | Lte | Gt | Gte) as op ->
overloaded op
| ( Minus_int | Minus_rat | Minus_mon | Minus_dur | Add_int_int | Add_rat_rat
| Add_mon_mon | Add_dat_dur | Add_dur_dur | Sub_int_int | Sub_rat_rat
| Sub_mon_mon | Sub_dat_dat | Sub_dat_dur | Sub_dur_dur | Mult_int_int
| Mult_rat_rat | Mult_mon_rat | Mult_dur_int | Div_int_int | Div_rat_rat
| Div_mon_mon | Div_mon_rat | Lt_int_int | Lt_rat_rat | Lt_mon_mon
| Lt_dat_dat | Lt_dur_dur | Lte_int_int | Lte_rat_rat | Lte_mon_mon
| Lte_dat_dat | Lte_dur_dur | Gt_int_int | Gt_rat_rat | Gt_mon_mon
| Gt_dat_dat | Gt_dur_dur | Gte_int_int | Gte_rat_rat | Gte_mon_mon
| Gte_dat_dat | Gte_dur_dur | Eq_int_int | Eq_rat_rat | Eq_mon_mon
| Eq_dat_dat | Eq_dur_dur ) as op ->
resolved op
(* Glorified identity... allowed operators are the same in scopelang, dcalc,
lcalc *)
let translate :
type k.
([< scopelang | dcalc | lcalc ], k) t ->
([< scopelang | dcalc | lcalc ], k) t =
fun op ->
match op with
| Length -> Length
| Log (i, l) -> Log (i, l)
| Eq -> Eq
| Map -> Map
| Concat -> Concat
| Filter -> Filter
| Fold -> Fold
| Not -> Not
| IntToRat -> IntToRat
| MoneyToRat -> MoneyToRat
| RatToMoney -> RatToMoney
| GetDay -> GetDay
| GetMonth -> GetMonth
| GetYear -> GetYear
| FirstDayOfMonth -> FirstDayOfMonth
| LastDayOfMonth -> LastDayOfMonth
| RoundMoney -> RoundMoney
| RoundDecimal -> RoundDecimal
| And -> And
| Or -> Or
| Xor -> Xor
| Minus_int -> Minus_int
| Minus_rat -> Minus_rat
| Minus_mon -> Minus_mon
| Minus_dur -> Minus_dur
| Add_int_int -> Add_int_int
| Add_rat_rat -> Add_rat_rat
| Add_mon_mon -> Add_mon_mon
| Add_dat_dur -> Add_dat_dur
| Add_dur_dur -> Add_dur_dur
| Sub_int_int -> Sub_int_int
| Sub_rat_rat -> Sub_rat_rat
| Sub_mon_mon -> Sub_mon_mon
| Sub_dat_dat -> Sub_dat_dat
| Sub_dat_dur -> Sub_dat_dur
| Sub_dur_dur -> Sub_dur_dur
| Mult_int_int -> Mult_int_int
| Mult_rat_rat -> Mult_rat_rat
| Mult_mon_rat -> Mult_mon_rat
| Mult_dur_int -> Mult_dur_int
| Div_int_int -> Div_int_int
| Div_rat_rat -> Div_rat_rat
| Div_mon_mon -> Div_mon_mon
| Div_mon_rat -> Div_mon_rat
| Lt_int_int -> Lt_int_int
| Lt_rat_rat -> Lt_rat_rat
| Lt_mon_mon -> Lt_mon_mon
| Lt_dat_dat -> Lt_dat_dat
| Lt_dur_dur -> Lt_dur_dur
| Lte_int_int -> Lte_int_int
| Lte_rat_rat -> Lte_rat_rat
| Lte_mon_mon -> Lte_mon_mon
| Lte_dat_dat -> Lte_dat_dat
| Lte_dur_dur -> Lte_dur_dur
| Gt_int_int -> Gt_int_int
| Gt_rat_rat -> Gt_rat_rat
| Gt_mon_mon -> Gt_mon_mon
| Gt_dat_dat -> Gt_dat_dat
| Gt_dur_dur -> Gt_dur_dur
| Gte_int_int -> Gte_int_int
| Gte_rat_rat -> Gte_rat_rat
| Gte_mon_mon -> Gte_mon_mon
| Gte_dat_dat -> Gte_dat_dat
| Gte_dur_dur -> Gte_dur_dur
| Eq_int_int -> Eq_int_int
| Eq_rat_rat -> Eq_rat_rat
| Eq_mon_mon -> Eq_mon_mon
| Eq_dat_dat -> Eq_dat_dat
| Eq_dur_dur -> Eq_dur_dur
let monomorphic_type (op, pos) =
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
match op with
| Not -> TBool @-> TBool
| IntToRat -> TInt @-> TRat
| MoneyToRat -> TMoney @-> TRat
| RatToMoney -> TRat @-> TMoney
| GetDay -> TDate @-> TInt
| GetMonth -> TDate @-> TInt
| GetYear -> TDate @-> TInt
| FirstDayOfMonth -> TDate @-> TDate
| LastDayOfMonth -> TDate @-> TDate
| RoundMoney -> TMoney @-> TMoney
| RoundDecimal -> TRat @-> TRat
| And -> TBool @- TBool @-> TBool
| Or -> TBool @- TBool @-> TBool
| Xor -> TBool @- TBool @-> TBool
let resolved_type (op, pos) =
let ( @- ) a b = TArrow ((TLit a, pos), b), pos in
let ( @-> ) a b = TArrow ((TLit a, pos), (TLit b, pos)), pos in
match op with
| Minus_int -> TInt @-> TInt
| Minus_rat -> TRat @-> TRat
| Minus_mon -> TMoney @-> TMoney
| Minus_dur -> TDuration @-> TDuration
| Add_int_int -> TInt @- TInt @-> TInt
| Add_rat_rat -> TRat @- TRat @-> TRat
| Add_mon_mon -> TMoney @- TMoney @-> TMoney
| Add_dat_dur -> TDate @- TDuration @-> TDate
| Add_dur_dur -> TDuration @- TDuration @-> TDuration
| Sub_int_int -> TInt @- TInt @-> TInt
| Sub_rat_rat -> TRat @- TRat @-> TRat
| Sub_mon_mon -> TMoney @- TMoney @-> TMoney
| Sub_dat_dat -> TDate @- TDate @-> TDuration
| Sub_dat_dur -> TDate @- TDuration @-> TDuration
| Sub_dur_dur -> TDuration @- TDuration @-> TDuration
| Mult_int_int -> TInt @- TInt @-> TInt
| Mult_rat_rat -> TRat @- TRat @-> TRat
| Mult_mon_rat -> TMoney @- TRat @-> TMoney
| Mult_dur_int -> TDuration @- TInt @-> TDuration
| Div_int_int -> TInt @- TInt @-> TInt
| Div_rat_rat -> TRat @- TRat @-> TRat
| Div_mon_mon -> TMoney @- TMoney @-> TRat
| Div_mon_rat -> TMoney @- TRat @-> TMoney
| Lt_int_int -> TInt @- TInt @-> TBool
| Lt_rat_rat -> TRat @- TRat @-> TBool
| Lt_mon_mon -> TMoney @- TMoney @-> TBool
| Lt_dat_dat -> TDate @- TDate @-> TBool
| Lt_dur_dur -> TDuration @- TDuration @-> TBool
| Lte_int_int -> TInt @- TInt @-> TBool
| Lte_rat_rat -> TRat @- TRat @-> TBool
| Lte_mon_mon -> TMoney @- TMoney @-> TBool
| Lte_dat_dat -> TDate @- TDate @-> TBool
| Lte_dur_dur -> TDuration @- TDuration @-> TBool
| Gt_int_int -> TInt @- TInt @-> TBool
| Gt_rat_rat -> TRat @- TRat @-> TBool
| Gt_mon_mon -> TMoney @- TMoney @-> TBool
| Gt_dat_dat -> TDate @- TDate @-> TBool
| Gt_dur_dur -> TDuration @- TDuration @-> TBool
| Gte_int_int -> TInt @- TInt @-> TBool
| Gte_rat_rat -> TRat @- TRat @-> TBool
| Gte_mon_mon -> TMoney @- TMoney @-> TBool
| Gte_dat_dat -> TDate @- TDate @-> TBool
| Gte_dur_dur -> TDuration @- TDuration @-> TBool
| Eq_int_int -> TInt @- TInt @-> TBool
| Eq_rat_rat -> TRat @- TRat @-> TBool
| Eq_mon_mon -> TMoney @- TMoney @-> TBool
| Eq_dat_dat -> TDate @- TDate @-> TBool
| Eq_dur_dur -> TDuration @- TDuration @-> TBool
let resolve_overload_aux (op : ('a, overloaded) t) (operands : typ_lit list) :
('b, resolved) t * [ `Straight | `Reversed ] =
match op, operands with
| Minus, [TInt] -> Minus_int, `Straight
| Minus, [TRat] -> Minus_rat, `Straight
| Minus, [TMoney] -> Minus_mon, `Straight
| Minus, [TDuration] -> Minus_dur, `Straight
| Add, [TInt; TInt] -> Add_int_int, `Straight
| Add, [TRat; TRat] -> Add_rat_rat, `Straight
| Add, [TMoney; TMoney] -> Add_mon_mon, `Straight
| Add, [TDuration; TDuration] -> Add_dur_dur, `Straight
| Add, [TDate; TDuration] -> Add_dat_dur, `Straight
| Add, [TDuration; TDate] -> Add_dat_dur, `Reversed
| Sub, [TInt; TInt] -> Sub_int_int, `Straight
| Sub, [TRat; TRat] -> Sub_rat_rat, `Straight
| Sub, [TMoney; TMoney] -> Sub_mon_mon, `Straight
| Sub, [TDuration; TDuration] -> Sub_dur_dur, `Straight
| Sub, [TDate; TDate] -> Sub_dat_dat, `Straight
| Sub, [TDate; TDuration] -> Sub_dat_dur, `Straight
| Mult, [TInt; TInt] -> Mult_int_int, `Straight
| Mult, [TRat; TRat] -> Mult_rat_rat, `Straight
| Mult, [TMoney; TRat] -> Mult_mon_rat, `Straight
| Mult, [TRat; TMoney] -> Mult_mon_rat, `Reversed
| Mult, [TDuration; TInt] -> Mult_dur_int, `Straight
| Mult, [TInt; TDuration] -> Mult_dur_int, `Reversed
| Div, [TInt; TInt] -> Div_int_int, `Straight
| Div, [TRat; TRat] -> Div_rat_rat, `Straight
| Div, [TMoney; TMoney] -> Div_mon_mon, `Straight
| Div, [TMoney; TRat] -> Div_mon_rat, `Straight
| Lt, [TInt; TInt] -> Lt_int_int, `Straight
| Lt, [TRat; TRat] -> Lt_rat_rat, `Straight
| Lt, [TMoney; TMoney] -> Lt_mon_mon, `Straight
| Lt, [TDuration; TDuration] -> Lt_dur_dur, `Straight
| Lt, [TDate; TDate] -> Lt_dat_dat, `Straight
| Lte, [TInt; TInt] -> Lte_int_int, `Straight
| Lte, [TRat; TRat] -> Lte_rat_rat, `Straight
| Lte, [TMoney; TMoney] -> Lte_mon_mon, `Straight
| Lte, [TDuration; TDuration] -> Lte_dur_dur, `Straight
| Lte, [TDate; TDate] -> Lte_dat_dat, `Straight
| Gt, [TInt; TInt] -> Gt_int_int, `Straight
| Gt, [TRat; TRat] -> Gt_rat_rat, `Straight
| Gt, [TMoney; TMoney] -> Gt_mon_mon, `Straight
| Gt, [TDuration; TDuration] -> Gt_dur_dur, `Straight
| Gt, [TDate; TDate] -> Gt_dat_dat, `Straight
| Gte, [TInt; TInt] -> Gte_int_int, `Straight
| Gte, [TRat; TRat] -> Gte_rat_rat, `Straight
| Gte, [TMoney; TMoney] -> Gte_mon_mon, `Straight
| Gte, [TDuration; TDuration] -> Gte_dur_dur, `Straight
| Gte, [TDate; TDate] -> Gte_dat_dat, `Straight
| (Minus | Add | Sub | Mult | Div | Lt | Lte | Gt | Gte), _ -> raise Not_found
let resolve_overload
ctx
(op : ('a, overloaded) t Marked.pos)
(operands : typ list) : ('b, resolved) t * [ `Straight | `Reversed ] =
try
let operands =
List.map
(fun t ->
match Marked.unmark t with TLit tl -> tl | _ -> raise Not_found)
operands
in
resolve_overload_aux (Marked.unmark op) operands
with Not_found ->
Errors.raise_multispanned_error
((None, Marked.get_mark op)
:: List.map
(fun ty ->
( Some
(Format.asprintf "Type %a coming from expression:"
(Print.typ ctx) ty),
Marked.get_mark ty ))
operands)
"I don't know how to apply operator %a on types %a" Print.operator
(Marked.unmark op)
(Format.pp_print_list
~pp_sep:(fun ppf () -> Format.fprintf ppf " and@ ")
(Print.typ ctx))
operands
let overload_type ctx (op : ('a, overloaded) t Marked.pos) (operands : typ list)
: typ =
let rop = fst (resolve_overload ctx op operands) in
resolved_type (Marked.same_mark_as rop op)

View File

@ -0,0 +1,69 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
(** {1 Catala operator utilities} *)
open Catala_utils
open Definitions
include module type of Definitions.Op
val equal : ('a1, 'k1) t -> ('a2, 'k2) t -> bool
val compare : ('a1, 'k1) t -> ('a2, 'k2) t -> int
val name : ('a, 'k) t -> string
(** Returns the operator name as a valid ident starting with a lowercase
character. This is different from Print.operator which returns operator
symbols, e.g. [+$]. *)
val kind_dispatch :
polymorphic:((_ any, polymorphic) t -> 'b) ->
monomorphic:((_ any, monomorphic) t -> 'b) ->
?overloaded:((desugared, overloaded) t -> 'b) ->
?resolved:(([< scopelang | dcalc | lcalc ], resolved) t -> 'b) ->
('a, 'k) t ->
'b
(** Calls one of the supplied functions depending on the kind of the operator *)
val translate :
([< scopelang | dcalc | lcalc ], 'k) t ->
([< scopelang | dcalc | lcalc ], 'k) t
(** An identity function that allows translating an operator between different
passes that don't change operator types *)
(** {2 Getting the types of operators} *)
val monomorphic_type : ('a any, monomorphic) t Marked.pos -> typ
val resolved_type :
([< scopelang | dcalc | lcalc ], resolved) t Marked.pos -> typ
val overload_type :
decl_ctx -> (desugared, overloaded) t Marked.pos -> typ list -> typ
(** The type for typing overloads is different since the types of the operands
are required in advance.
@raise a detailed user error if no matching operator can be found *)
(** Polymorphic operators are typed directly within [Typing], since their types
may contain type variables that can't be expressed outside of it*)
(** {2 Overload handling} *)
val resolve_overload :
decl_ctx ->
(desugared, overloaded) t Marked.pos ->
typ list ->
([< scopelang | dcalc | lcalc ], resolved) t * [ `Straight | `Reversed ]

View File

@ -42,7 +42,7 @@ let base_type (fmt : Format.formatter) (s : string) : unit =
let punctuation (fmt : Format.formatter) (s : string) : unit =
Cli.format_with_style [ANSITerminal.cyan] fmt s
let operator (fmt : Format.formatter) (s : string) : unit =
let op_style (fmt : Format.formatter) (s : string) : unit =
Cli.format_with_style [ANSITerminal.green] fmt s
let lit_style (fmt : Format.formatter) (s : string) : unit =
@ -81,7 +81,7 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
| TTuple ts ->
Format.fprintf fmt "@[<hov 2>(%a)@]"
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " operator "*")
~pp_sep:(fun fmt () -> Format.fprintf fmt "@ %a@ " op_style "*")
typ)
ts
| TStruct s -> (
@ -113,7 +113,7 @@ let rec typ (ctx : decl_ctx option) (fmt : Format.formatter) (ty : typ) : unit =
punctuation "]")
| TOption t -> Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "option" typ t
| TArrow (t1, t2) ->
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 operator ""
Format.fprintf fmt "@[<hov 2>%a %a@ %a@]" typ_with_parens t1 op_style ""
typ t2
| TArray t1 ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" base_type "collection" typ t1
@ -137,38 +137,6 @@ let lit (type a) (fmt : Format.formatter) (l : a glit) : unit =
| LDate d -> lit_style fmt (Runtime.date_to_string d)
| LDuration d -> lit_style fmt (Runtime.duration_to_string d)
let op_kind (fmt : Format.formatter) (k : 'a op_kind) =
Format.fprintf fmt "%s"
(match k with
| KInt -> ""
| KRat -> "."
| KMoney -> "$"
| KDate -> "@"
| KDuration -> "^")
let binop (fmt : Format.formatter) (op : 'a binop) : unit =
operator fmt
(match op with
| Add k -> Format.asprintf "+%a" op_kind k
| Sub k -> Format.asprintf "-%a" op_kind k
| Mult k -> Format.asprintf "*%a" op_kind k
| Div k -> Format.asprintf "/%a" op_kind k
| And -> "&&"
| Or -> "||"
| Xor -> "xor"
| Eq -> "="
| Neq -> "!="
| Lt k -> Format.asprintf "%s%a" "<" op_kind k
| Lte k -> Format.asprintf "%s%a" "<=" op_kind k
| Gt k -> Format.asprintf "%s%a" ">" op_kind k
| Gte k -> Format.asprintf "%s%a" ">=" op_kind k
| Concat -> "++"
| Map -> "map"
| Filter -> "filter")
let ternop (fmt : Format.formatter) (op : ternop) : unit =
match op with Fold -> keyword fmt "fold"
let log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
Format.fprintf fmt "@<2>%a"
(fun fmt -> function
@ -179,30 +147,98 @@ let log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
Cli.format_with_style [ANSITerminal.green] fmt "")
entry
let unop (fmt : Format.formatter) (op : 'a unop) : unit =
let operator_to_string : type a k. (a, k) Op.t -> string = function
| Not -> "~"
| Length -> "length"
| IntToRat -> "int_to_rat"
| MoneyToRat -> "money_to_rat"
| RatToMoney -> "rat_to_money"
| GetDay -> "get_day"
| GetMonth -> "get_month"
| GetYear -> "get_year"
| FirstDayOfMonth -> "first_day_of_month"
| LastDayOfMonth -> "last_day_of_month"
| RoundMoney -> "round_money"
| RoundDecimal -> "round_decimal"
| Log _ -> "Log"
| Minus -> "-"
| Minus_int -> "-!"
| Minus_rat -> "-."
| Minus_mon -> "-$"
| Minus_dur -> "-^"
| And -> "&&"
| Or -> "||"
| Xor -> "xor"
| Eq -> "="
| Map -> "map"
| Concat -> "++"
| Filter -> "filter"
| Add -> "+"
| Add_int_int -> "+!"
| Add_rat_rat -> "+."
| Add_mon_mon -> "+$"
| Add_dat_dur -> "+@"
| Add_dur_dur -> "+^"
| Sub -> "-"
| Sub_int_int -> "-!"
| Sub_rat_rat -> "-."
| Sub_mon_mon -> "-$"
| Sub_dat_dat -> "-@"
| Sub_dat_dur -> "-@^"
| Sub_dur_dur -> "-^"
| Mult -> "*"
| Mult_int_int -> "*!"
| Mult_rat_rat -> "*."
| Mult_mon_rat -> "*$"
| Mult_dur_int -> "*^"
| Div -> "/"
| Div_int_int -> "/!"
| Div_rat_rat -> "/."
| Div_mon_mon -> "/$"
| Div_mon_rat -> "/$."
| Lt -> "<"
| Lt_int_int -> "<!"
| Lt_rat_rat -> "<."
| Lt_mon_mon -> "<$"
| Lt_dur_dur -> "<^"
| Lt_dat_dat -> "<@"
| Lte -> "<="
| Lte_int_int -> "<=!"
| Lte_rat_rat -> "<=."
| Lte_mon_mon -> "<=$"
| Lte_dur_dur -> "<=^"
| Lte_dat_dat -> "<=@"
| Gt -> ">"
| Gt_int_int -> ">!"
| Gt_rat_rat -> ">."
| Gt_mon_mon -> ">$"
| Gt_dur_dur -> ">^"
| Gt_dat_dat -> ">@"
| Gte -> ">="
| Gte_int_int -> ">=!"
| Gte_rat_rat -> ">=."
| Gte_mon_mon -> ">=$"
| Gte_dur_dur -> ">=^"
| Gte_dat_dat -> ">=@"
| Eq_int_int -> "=!"
| Eq_rat_rat -> "=."
| Eq_mon_mon -> "=$"
| Eq_dur_dur -> "=^"
| Eq_dat_dat -> "=@"
| Fold -> "fold"
let operator (type k) (fmt : Format.formatter) (op : ('a, k) operator) : unit =
match op with
| Minus _ -> Format.pp_print_string fmt "-"
| Not -> Format.pp_print_string fmt "~"
| Log (entry, infos) ->
Format.fprintf fmt "log@[<hov 2>[%a|%a]@]" log_entry entry
Format.fprintf fmt "%a@[<hov 2>[%a|%a]@]" op_style "log" log_entry entry
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt ".")
(fun fmt info -> Uid.MarkedString.format fmt info))
infos
| Length -> Format.pp_print_string fmt "length"
| IntToRat -> Format.pp_print_string fmt "int_to_rat"
| MoneyToRat -> Format.pp_print_string fmt "money_to_rat"
| RatToMoney -> Format.pp_print_string fmt "rat_to_money"
| GetDay -> Format.pp_print_string fmt "get_day"
| GetMonth -> Format.pp_print_string fmt "get_month"
| GetYear -> Format.pp_print_string fmt "get_year"
| FirstDayOfMonth -> Format.pp_print_string fmt "first_day_of_month"
| LastDayOfMonth -> Format.pp_print_string fmt "last_day_of_month"
| RoundMoney -> Format.pp_print_string fmt "round_money"
| RoundDecimal -> Format.pp_print_string fmt "round_decimal"
| op -> Format.fprintf fmt "%a" op_style (operator_to_string op)
let except (fmt : Format.formatter) (exn : except) : unit =
operator fmt
op_style fmt
(match exn with
| EmptyError -> "EmptyError"
| ConflictError -> "ConflictError"
@ -279,16 +315,16 @@ let rec expr_aux :
Format.fprintf fmt "%a%a%a %a%a" punctuation "(" var x punctuation
":" (typ ctx) tau punctuation ")"))
xs_tau punctuation "" expr body
| EApp { f = EOp (Binop ((Map | Filter) as op)), _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" binop op with_parens arg1
| EApp { f = EOp { op = (Map | Filter) as op; _ }, _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" operator op with_parens arg1
with_parens arg2
| EApp { f = EOp (Binop op), _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 binop op
| EApp { f = EOp { op; _ }, _; args = [arg1; arg2] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@]" with_parens arg1 operator op
with_parens arg2
| EApp { f = EOp (Unop (Log _)), _; args = [arg1] } when not debug ->
| EApp { f = EOp { op = Log _; _ }, _; args = [arg1] } when not debug ->
expr fmt arg1
| EApp { f = EOp (Unop op), _; args = [arg1] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" unop op with_parens arg1
| EApp { f = EOp { op; _ }, _; args = [arg1] } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" operator op with_parens arg1
| EApp { f; args } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@]" expr f
(Format.pp_print_list
@ -298,9 +334,7 @@ let rec expr_aux :
| EIfThenElse { cond; etrue; efalse } ->
Format.fprintf fmt "@[<hov 2>%a@ %a@ %a@ %a@ %a@ %a@]" keyword "if" expr
cond keyword "then" expr etrue keyword "else" expr efalse
| EOp (Ternop op) -> ternop fmt op
| EOp (Binop op) -> binop fmt op
| EOp (Unop op) -> unop fmt op
| EOp { op; _ } -> operator fmt op
| EDefault { excepts; just; cons } ->
if List.length excepts = 0 then
Format.fprintf fmt "@[<hov 2>%a%a@ %a@ %a%a@]" punctuation "" expr just
@ -313,7 +347,7 @@ let rec expr_aux :
excepts punctuation "|" expr just punctuation "" expr cons punctuation
""
| EErrorOnEmpty e' ->
Format.fprintf fmt "%a@ %a" operator "error_empty" with_parens e'
Format.fprintf fmt "%a@ %a" op_style "error_empty" with_parens e'
| EAssert e' ->
Format.fprintf fmt "@[<hov 2>%a@ %a%a%a@]" keyword "assert" punctuation "("
expr e' punctuation ")"

View File

@ -24,7 +24,7 @@ open Definitions
val base_type : Format.formatter -> string -> unit
val keyword : Format.formatter -> string -> unit
val punctuation : Format.formatter -> string -> unit
val operator : Format.formatter -> string -> unit
val op_style : Format.formatter -> string -> unit
val lit_style : Format.formatter -> string -> unit
(** {1 Formatters} *)
@ -35,11 +35,8 @@ val tlit : Format.formatter -> typ_lit -> unit
val location : Format.formatter -> 'a glocation -> unit
val typ : decl_ctx -> Format.formatter -> typ -> unit
val lit : Format.formatter -> 'a glit -> unit
val op_kind : Format.formatter -> 'a any op_kind -> unit
val binop : Format.formatter -> 'a any binop -> unit
val ternop : Format.formatter -> ternop -> unit
val operator : Format.formatter -> ('a any, 'k) operator -> unit
val log_entry : Format.formatter -> log_entry -> unit
val unop : Format.formatter -> 'a any unop -> unit
val except : Format.formatter -> except -> unit
val var : Format.formatter -> 'e Var.t -> unit
val var_debug : Format.formatter -> 'e Var.t -> unit

View File

@ -16,6 +16,8 @@
include Definitions
module Var = Var
module Type = Type
module Operator = Operator
module Expr = Expr
module Scope = Scope
module Program = Program

View File

@ -0,0 +1,87 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
open Catala_utils
open Definitions
type t = typ
let equal_tlit l1 l2 = l1 = l2
let compare_tlit l1 l2 = Stdlib.compare l1 l2
let rec equal ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> equal_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> equal t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> equal t1 t2 && equal t1' t2'
| TArray t1, TArray t2 -> equal t1 t2
| TAny, TAny -> true
| ( ( TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _
| TArray _ | TAny ),
_ ) ->
false
and equal_list tys1 tys2 =
try List.for_all2 equal tys1 tys2 with Invalid_argument _ -> false
(* Similar to [equal], but allows TAny holes *)
let rec unifiable ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TAny, _ | _, TAny -> true
| TLit l1, TLit l2 -> equal_tlit l1 l2
| TTuple tys1, TTuple tys2 -> unifiable_list tys1 tys2
| TStruct n1, TStruct n2 -> StructName.equal n1 n2
| TEnum n1, TEnum n2 -> EnumName.equal n1 n2
| TOption t1, TOption t2 -> unifiable t1 t2
| TArrow (t1, t1'), TArrow (t2, t2') -> unifiable t1 t2 && unifiable t1' t2'
| TArray t1, TArray t2 -> unifiable t1 t2
| ( (TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TArrow _ | TArray _),
_ ) ->
false
and unifiable_list tys1 tys2 =
try List.for_all2 unifiable tys1 tys2 with Invalid_argument _ -> false
let rec compare ty1 ty2 =
match Marked.unmark ty1, Marked.unmark ty2 with
| TLit l1, TLit l2 -> compare_tlit l1 l2
| TTuple tys1, TTuple tys2 -> List.compare compare tys1 tys2
| TStruct n1, TStruct n2 -> StructName.compare n1 n2
| TEnum en1, TEnum en2 -> EnumName.compare en1 en2
| TOption t1, TOption t2 -> compare t1 t2
| TArrow (a1, b1), TArrow (a2, b2) -> (
match compare a1 a2 with 0 -> compare b1 b2 | n -> n)
| TArray t1, TArray t2 -> compare t1 t2
| TAny, TAny -> 0
| TLit _, _ -> -1
| _, TLit _ -> 1
| TTuple _, _ -> -1
| _, TTuple _ -> 1
| TStruct _, _ -> -1
| _, TStruct _ -> 1
| TEnum _, _ -> -1
| _, TEnum _ -> 1
| TOption _, _ -> -1
| _, TOption _ -> 1
| TArrow _, _ -> -1
| _, TArrow _ -> 1
| TArray _, _ -> -1
| _, TArray _ -> 1
let rec arrow_return = function TArrow (_, b), _ -> arrow_return b | t -> t

View File

@ -0,0 +1,27 @@
(* This file is part of the Catala compiler, a specification language for tax
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
Louis Gesbert <louis.gesbert@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
type t = Definitions.typ
val equal : t -> t -> bool
val equal_list : t list -> t list -> bool
val compare : t -> t -> int
val unifiable : t -> t -> bool
(** Similar to [equal], but allows TAny holes *)
val arrow_return : t -> t
(** Returns the last member in nested [TArrow] types *)

View File

@ -225,73 +225,48 @@ let lit_type (type a) (lit : a A.glit) : naked_typ =
| LUnit -> TLit TUnit
| LEmptyError -> TAny (Any.fresh ())
(** Operators have a single type, instead of being polymorphic with constraints.
This allows us to have a simpler type system, while we argue the syntactic
burden of operator annotations helps the programmer visualize the type flow
in the code. *)
let op_type (op : 'a A.operator Marked.pos) : unionfind_typ =
(** [op_type] and [resolve_overload] are a bit similar, and work on disjoint
sets of operators. However, their assumptions are different so we keep the
functions separate. In particular [resolve_overloads] requires its argument
types to be known in advance. *)
let polymorphic_op_type (op : ('a, Operator.polymorphic) A.operator Marked.pos)
: unionfind_typ =
let open Operator in
let pos = Marked.get_mark op in
let bt = UnionFind.make (TLit TBool, pos) in
let it = UnionFind.make (TLit TInt, pos) in
let rt = UnionFind.make (TLit TRat, pos) in
let mt = UnionFind.make (TLit TMoney, pos) in
let dut = UnionFind.make (TLit TDuration, pos) in
let dat = UnionFind.make (TLit TDate, pos) in
let any = UnionFind.make (TAny (Any.fresh ()), pos) in
let array_any = UnionFind.make (TArray any, pos) in
let any2 = UnionFind.make (TAny (Any.fresh ()), pos) in
let array_any2 = UnionFind.make (TArray any2, pos) in
let arr x y = UnionFind.make (TArrow (x, y), pos) in
match Marked.unmark op with
| A.Ternop A.Fold ->
arr (arr any2 (arr any any2)) (arr any2 (arr array_any any2))
| A.Binop (A.And | A.Or | A.Xor) -> arr bt (arr bt bt)
| A.Binop (A.Add KInt | A.Sub KInt | A.Mult KInt | A.Div KInt) ->
arr it (arr it it)
| A.Binop (A.Add KRat | A.Sub KRat | A.Mult KRat | A.Div KRat) ->
arr rt (arr rt rt)
| A.Binop (A.Add KMoney | A.Sub KMoney) -> arr mt (arr mt mt)
| A.Binop (A.Add KDuration | A.Sub KDuration) -> arr dut (arr dut dut)
| A.Binop (A.Sub KDate) -> arr dat (arr dat dut)
| A.Binop (A.Add KDate) -> arr dat (arr dut dat)
| A.Binop (A.Mult KDuration) -> arr dut (arr it dut)
| A.Binop (A.Div KMoney) -> arr mt (arr mt rt)
| A.Binop (A.Mult KMoney) -> arr mt (arr rt mt)
| A.Binop (A.Lt KInt | A.Lte KInt | A.Gt KInt | A.Gte KInt) ->
arr it (arr it bt)
| A.Binop (A.Lt KRat | A.Lte KRat | A.Gt KRat | A.Gte KRat) ->
arr rt (arr rt bt)
| A.Binop (A.Lt KMoney | A.Lte KMoney | A.Gt KMoney | A.Gte KMoney) ->
arr mt (arr mt bt)
| A.Binop (A.Lt KDate | A.Lte KDate | A.Gt KDate | A.Gte KDate) ->
arr dat (arr dat bt)
| A.Binop (A.Lt KDuration | A.Lte KDuration | A.Gt KDuration | A.Gte KDuration)
->
arr dut (arr dut bt)
| A.Binop (A.Eq | A.Neq) -> arr any (arr any bt)
| A.Binop A.Map -> arr (arr any any2) (arr array_any array_any2)
| A.Binop A.Filter -> arr (arr any bt) (arr array_any array_any)
| A.Binop A.Concat -> arr array_any (arr array_any array_any)
| A.Unop (A.Minus KInt) -> arr it it
| A.Unop (A.Minus KRat) -> arr rt rt
| A.Unop (A.Minus KMoney) -> arr mt mt
| A.Unop (A.Minus KDuration) -> arr dut dut
| A.Unop A.Not -> arr bt bt
| A.Unop (A.Log (A.PosRecordIfTrueBool, _)) -> arr bt bt
| A.Unop (A.Log _) -> arr any any
| A.Unop A.Length -> arr array_any it
| A.Unop A.GetDay -> arr dat it
| A.Unop A.GetMonth -> arr dat it
| A.Unop A.GetYear -> arr dat it
| A.Unop A.FirstDayOfMonth -> arr dat dat
| A.Unop A.LastDayOfMonth -> arr dat dat
| A.Unop A.RoundMoney -> arr mt mt
| A.Unop A.RoundDecimal -> arr rt rt
| A.Unop A.IntToRat -> arr it rt
| A.Unop A.MoneyToRat -> arr mt rt
| A.Unop A.RatToMoney -> arr rt mt
| Binop (Mult KDate) | Binop (Div (KDate | KDuration)) | Unop (Minus KDate) ->
Errors.raise_spanned_error pos "This operator is not available!"
let any = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
let any2 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in
let bt = lazy (UnionFind.make (TLit TBool, pos)) in
let it = lazy (UnionFind.make (TLit TInt, pos)) in
let array a = lazy (UnionFind.make (TArray (Lazy.force a), pos)) in
let ( @-> ) x y =
lazy (UnionFind.make (TArrow (Lazy.force x, Lazy.force y), pos))
in
let ty =
match Marked.unmark op with
| Fold -> (any2 @-> any @-> any2) @-> any2 @-> array any @-> any2
| Eq -> any @-> any @-> bt
| Map -> (any @-> any2) @-> array any @-> array any2
| Filter -> (any @-> bt) @-> array any @-> array any
| Concat -> array any @-> array any @-> array any
| Log (PosRecordIfTrueBool, _) -> bt @-> bt
| Log _ -> any @-> any
| Length -> array any @-> it
in
Lazy.force ty
let resolve_overload_ret_type
(ctx : A.decl_ctx)
e
(op : ('a A.any, Operator.overloaded) A.operator)
tys : unionfind_typ =
let op_ty =
Operator.overload_type ctx
(Marked.mark (Expr.pos e) op)
(List.map (typ_to_ast ~unsafe:true) tys)
(* We use [unsafe] because the error is caught below *)
in
ast_to_typ (Type.arrow_return op_ty)
(** {1 Double-directed typing} *)
@ -605,24 +580,41 @@ and typecheck_expr_top_down :
let body' = typecheck_expr_top_down ctx env t_ret body in
let binder' = Bindlib.bind_mvar xs' (Expr.Box.lift body') in
Expr.eabs binder' (List.map typ_to_ast tau_args) mark
| A.EApp { f = (EOp _, _) as e1; args } ->
(* Same as EApp, but the typing order is different to help with
disambiguation: - type of the operator is extracted first (to figure
linked type vars between arguments) - arguments are typed right-to-left,
because our operators with function args always have the functions first,
and the argument types of those functions can always be inferred from the
later operator arguments *)
let t_args = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) args in
| A.EApp { f = (EOp { op; tys }, _) as e1; args } ->
let t_args = List.map ast_to_typ tys in
let t_func =
List.fold_right
(fun t_arg acc -> unionfind (TArrow (t_arg, acc)))
t_args tau
in
let e1' = typecheck_expr_top_down ctx env t_func e1 in
let args' =
List.rev_map2
(typecheck_expr_top_down ctx env)
(List.rev t_args) (List.rev args)
let e1', args' =
Operator.kind_dispatch op
~polymorphic:(fun _ ->
(* Type the operator first, then right-to-left: polymorphic operators
are required to allow the resolution of all type variables this
way *)
let e1' = typecheck_expr_top_down ctx env t_func e1 in
let args' =
List.rev_map2
(typecheck_expr_top_down ctx env)
(List.rev t_args) (List.rev args)
in
e1', args')
~overloaded:(fun _ ->
(* Typing the arguments first is required to resolve the operator *)
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
let e1' = typecheck_expr_top_down ctx env t_func e1 in
e1', args')
~monomorphic:(fun _ ->
(* Here it doesn't matter but may affect the error messages *)
let e1' = typecheck_expr_top_down ctx env t_func e1 in
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
e1', args')
~resolved:(fun _ ->
(* This case should not fail *)
let e1' = typecheck_expr_top_down ctx env t_func e1 in
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
e1', args')
in
Expr.eapp e1' args' context_mark
| A.EApp { f = e1; args } ->
@ -638,7 +630,35 @@ and typecheck_expr_top_down :
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
let e1' = typecheck_expr_top_down ctx env t_func e1 in
Expr.eapp e1' args' context_mark
| A.EOp op -> Expr.eop op (uf_mark (op_type (Marked.mark pos_e op)))
| A.EOp { op; tys } ->
let tys' = List.map ast_to_typ tys in
let t_ret = unionfind (TAny (Any.fresh ())) in
let t_func =
List.fold_right
(fun t_arg acc -> unionfind (TArrow (t_arg, acc)))
tys' t_ret
in
unify ctx e t_func tau;
let tys, mark =
Operator.kind_dispatch op
~polymorphic:(fun op ->
tys, uf_mark (polymorphic_op_type (Marked.mark pos_e op)))
~monomorphic:(fun op ->
let mark =
uf_mark
(ast_to_typ (Operator.monomorphic_type (Marked.mark pos_e op)))
in
List.map typ_to_ast tys', mark)
~overloaded:(fun op ->
unify ctx e t_ret (resolve_overload_ret_type ctx e op tys');
List.map typ_to_ast tys', { uf = t_func; pos = pos_e })
~resolved:(fun op ->
let mark =
uf_mark (ast_to_typ (Operator.resolved_type (Marked.mark pos_e op)))
in
List.map typ_to_ast tys', mark)
in
Expr.eop op tys mark
| A.EDefault { excepts; just; cons } ->
let cons' = typecheck_expr_top_down ctx env tau cons in
let just' =

View File

@ -246,7 +246,7 @@ type match_case_pattern =
name = "match_case_pattern_iter";
}]
type op_kind = KInt | KDec | KMoney | KDate | KDuration
type op_kind = KPoly | KInt | KDec | KMoney | KDate | KDuration
[@@deriving
visitors { variety = "map"; name = "op_kind_map"; nude = true },
visitors { variety = "iter"; name = "op_kind_iter"; nude = true }]
@ -387,9 +387,12 @@ type literal =
type aggregate_func =
| AggregateSum of primitive_typ
(* it would be nice to remove the need for specifying the type here like for
extremums, but we need an additionl overload for "neutral element for
addition across types" *)
| AggregateCount
| AggregateExtremum of bool * primitive_typ * expression Marked.pos
| AggregateArgExtremum of bool * primitive_typ * expression Marked.pos
| AggregateExtremum of bool * primitive_typ option * expression Marked.pos
| AggregateArgExtremum of bool * primitive_typ option * expression Marked.pos
and collection_op =
| Exists

View File

@ -263,6 +263,9 @@ module R = Re.Pcre
#ifndef MR_INTERNAL
#define MR_INTERNAL MS_INTERNAL
#endif
#ifndef MR_MONEY_OP_SUFFIX
#define MR_MONEY_OP_SUFFIX MS_MONEY_OP_SUFFIX
#endif
let token_list : (string * token) list =
[
@ -365,6 +368,18 @@ let space_plus = [%sedlex.regexp? Plus white_space]
(** Regexp matching white space but not newlines *)
let hspace = [%sedlex.regexp? Sub (white_space, Chars "\n\r")]
(** Operator explicit typing suffix chars *)
let op_kind_re = [%sedlex.regexp? "" | MR_MONEY_OP_SUFFIX | Chars "!.@^"]
let op_kind = function
| "" -> Ast.KPoly
| "!" -> Ast.KInt
| "." -> Ast.KDec
| MS_MONEY_OP_SUFFIX -> Ast.KMoney
| "@" -> Ast.KDate
| "^" -> Ast.KDuration
| _ -> invalid_arg "op_kind"
(** Main lexing function used in code blocks *)
let rec lex_code (lexbuf : lexbuf) : token =
let prev_lexeme = Utf8.lexeme lexbuf in
@ -629,117 +644,38 @@ let rec lex_code (lexbuf : lexbuf) : token =
L.update_acc lexbuf;
DECIMAL_LITERAL
(dec_parts 1, dec_parts 2)
| "<=@" ->
| "<=", op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:"<=" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
LESSER_EQUAL_DATE
| "<@" ->
LESSER_EQUAL k
| "<", op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:"<" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
LESSER_DATE
| ">=@" ->
LESSER k
| ">=", op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:">=" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
GREATER_EQUAL_DATE
| ">@" ->
GREATER_EQUAL k
| ">", op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:">" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
GREATER_DATE
| "-@" ->
GREATER k
| "-", op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:"-" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
MINUSDATE
| "+@" ->
MINUS k
| "+", op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:"+" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
PLUSDATE
| "<=^" ->
PLUS k
| "*", op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:"*" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
LESSER_EQUAL_DURATION
| "<^" ->
MULT k
| '/', op_kind_re ->
let k = op_kind (String.remove_prefix ~prefix:"/" (Utf8.lexeme lexbuf)) in
L.update_acc lexbuf;
LESSER_DURATION
| ">=^" ->
L.update_acc lexbuf;
GREATER_EQUAL_DURATION
| ">^" ->
L.update_acc lexbuf;
GREATER_DURATION
| "+^" ->
L.update_acc lexbuf;
PLUSDURATION
| "-^" ->
L.update_acc lexbuf;
MINUSDURATION
| "*^" ->
L.update_acc lexbuf;
MULDURATION
| "<=", MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
LESSER_EQUAL_MONEY
| '<', MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
LESSER_MONEY
| ">=", MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
GREATER_EQUAL_MONEY
| '>', MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
GREATER_MONEY
| '+', MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
PLUSMONEY
| '-', MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
MINUSMONEY
| '*', MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
MULTMONEY
| '/', MR_MONEY_OP_SUFFIX ->
L.update_acc lexbuf;
DIVMONEY
| "<=." ->
L.update_acc lexbuf;
LESSER_EQUAL_DEC
| "<." ->
L.update_acc lexbuf;
LESSER_DEC
| ">=." ->
L.update_acc lexbuf;
GREATER_EQUAL_DEC
| ">." ->
L.update_acc lexbuf;
GREATER_DEC
| "+." ->
L.update_acc lexbuf;
PLUSDEC
| "-." ->
L.update_acc lexbuf;
MINUSDEC
| "*." ->
L.update_acc lexbuf;
MULTDEC
| "/." ->
L.update_acc lexbuf;
DIVDEC
| "<=" ->
L.update_acc lexbuf;
LESSER_EQUAL
| '<' ->
L.update_acc lexbuf;
LESSER
| ">=" ->
L.update_acc lexbuf;
GREATER_EQUAL
| '>' ->
L.update_acc lexbuf;
GREATER
| '+' ->
L.update_acc lexbuf;
PLUS
| '-' ->
L.update_acc lexbuf;
MINUS
| '*' ->
L.update_acc lexbuf;
MULT
| '/' ->
L.update_acc lexbuf;
DIV
DIV k
| "!=" ->
L.update_acc lexbuf;
NOT_EQUAL

View File

@ -69,9 +69,9 @@ let raise_lexer_error (loc : Pos.t) (token : string) =
let token_list_language_agnostic : (string * token) list =
[
".", DOT;
"<=", LESSER_EQUAL;
">=", GREATER_EQUAL;
">", GREATER;
"<=", LESSER_EQUAL KPoly;
">=", GREATER_EQUAL KPoly;
">", GREATER KPoly;
"!=", NOT_EQUAL;
"=", EQUAL;
"(", LPAREN;
@ -80,10 +80,10 @@ let token_list_language_agnostic : (string * token) list =
"}", RBRACKET;
"{", LSQUARE;
"}", RSQUARE;
"+", PLUS;
"-", MINUS;
"*", MULT;
"/", DIV;
"+", PLUS KPoly;
"-", MINUS KPoly;
"*", MULT KPoly;
"/", DIV KPoly;
"|", VERTICAL;
":", COLON;
";", SEMICOLON;

View File

@ -93,7 +93,7 @@
(* Specific delimiters *)
#define MR_MONEY_OP_SUFFIX '$'
#define MS_MONEY_OP_SUFFIX "$"
#define MC_DECIMAL_SEPARATOR '.'
#define MR_MONEY_PREFIX '$', Star hspace
#define MR_MONEY_DELIM ','

View File

@ -114,7 +114,9 @@
(* Specific delimiters *)
#define MR_MONEY_OP_SUFFIX 0x20AC (* The euro sign *)
#define MS_MONEY_OP_SUFFIX ""
#define MR_MONEY_OP_SUFFIX 0x20AC
(* The euro sign *)
#define MC_DECIMAL_SEPARATOR ','
#define MR_MONEY_PREFIX ""
#define MR_MONEY_DELIM ' '

View File

@ -102,7 +102,7 @@
(* Specific delimiters *)
#define MR_MONEY_OP_SUFFIX '$'
#define MS_MONEY_OP_SUFFIX "$"
#define MC_DECIMAL_SEPARATOR '.'
#define MR_MONEY_PREFIX ""
#define MR_MONEY_DELIM ','

File diff suppressed because it is too large Load Diff

View File

@ -149,41 +149,25 @@ literal:
| FALSE { (LBool false, Pos.from_lpos $sloc) }
compare_op:
| LESSER { (Lt KInt, Pos.from_lpos $sloc) }
| LESSER_EQUAL { (Lte KInt, Pos.from_lpos $sloc) }
| GREATER { (Gt KInt, Pos.from_lpos $sloc) }
| GREATER_EQUAL { (Gte KInt, Pos.from_lpos $sloc) }
| LESSER_DEC { (Lt KDec, Pos.from_lpos $sloc) }
| LESSER_EQUAL_DEC { (Lte KDec, Pos.from_lpos $sloc) }
| GREATER_DEC { (Gt KDec, Pos.from_lpos $sloc) }
| GREATER_EQUAL_DEC { (Gte KDec, Pos.from_lpos $sloc) }
| LESSER_MONEY { (Lt KMoney, Pos.from_lpos $sloc) }
| LESSER_EQUAL_MONEY { (Lte KMoney, Pos.from_lpos $sloc) }
| GREATER_MONEY { (Gt KMoney, Pos.from_lpos $sloc) }
| GREATER_EQUAL_MONEY { (Gte KMoney, Pos.from_lpos $sloc) }
| LESSER_DATE { (Lt KDate, Pos.from_lpos $sloc) }
| LESSER_EQUAL_DATE { (Lte KDate, Pos.from_lpos $sloc) }
| GREATER_DATE { (Gt KDate, Pos.from_lpos $sloc) }
| GREATER_EQUAL_DATE { (Gte KDate, Pos.from_lpos $sloc) }
| LESSER_DURATION { (Lt KDuration, Pos.from_lpos $sloc) }
| LESSER_EQUAL_DURATION { (Lte KDuration, Pos.from_lpos $sloc) }
| GREATER_DURATION { (Gt KDuration, Pos.from_lpos $sloc) }
| GREATER_EQUAL_DURATION { (Gte KDuration, Pos.from_lpos $sloc) }
| LESSER { (Lt KPoly, Pos.from_lpos $sloc) }
| LESSER_EQUAL { (Lte KPoly, Pos.from_lpos $sloc) }
| GREATER { (Gt KPoly, Pos.from_lpos $sloc) }
| GREATER_EQUAL { (Gte KPoly, Pos.from_lpos $sloc) }
| EQUAL { (Eq, Pos.from_lpos $sloc) }
| NOT_EQUAL { (Neq, Pos.from_lpos $sloc) }
aggregate_func:
| CONTENT MAXIMUM t = typ_base INIT init = primitive_expression {
(Aggregate (AggregateArgExtremum (true, Marked.unmark t, init)), Pos.from_lpos $sloc)
| CONTENT MAXIMUM t = option(typ_base) INIT init = primitive_expression {
(Aggregate (AggregateArgExtremum (true, Option.map Marked.unmark t, init)), Pos.from_lpos $sloc)
}
| CONTENT MINIMUM t = typ_base INIT init = primitive_expression {
(Aggregate (AggregateArgExtremum (false, Marked.unmark t, init)), Pos.from_lpos $sloc)
| CONTENT MINIMUM t = option(typ_base) INIT init = primitive_expression {
(Aggregate (AggregateArgExtremum (false, Option.map Marked.unmark t, init)), Pos.from_lpos $sloc)
}
| MAXIMUM t = typ_base INIT init = primitive_expression {
(Aggregate (AggregateExtremum (true, Marked.unmark t, init)), Pos.from_lpos $sloc)
| MAXIMUM t = option(typ_base) INIT init = primitive_expression {
(Aggregate (AggregateExtremum (true, Option.map Marked.unmark t, init)), Pos.from_lpos $sloc)
}
| MINIMUM t = typ_base INIT init = primitive_expression {
(Aggregate (AggregateExtremum (false, Marked.unmark t, init)), Pos.from_lpos $sloc)
| MINIMUM t = option(typ_base) INIT init = primitive_expression {
(Aggregate (AggregateExtremum (false, Option.map Marked.unmark t, init)), Pos.from_lpos $sloc)
}
| SUM t = typ_base { (Aggregate (AggregateSum (Marked.unmark t)), Pos.from_lpos $sloc) }
| CARDINAL { (Aggregate AggregateCount, Pos.from_lpos $sloc) }
@ -216,23 +200,15 @@ base_expression:
unop:
| NOT { (Not, Pos.from_lpos $sloc) }
| MINUS { (Minus KInt, Pos.from_lpos $sloc) }
| MINUSDEC { (Minus KDec, Pos.from_lpos $sloc) }
| MINUSMONEY { (Minus KMoney, Pos.from_lpos $sloc) }
| MINUSDURATION { (Minus KDuration, Pos.from_lpos $sloc) }
| k = MINUS { (Minus k, Pos.from_lpos $sloc) }
unop_expression:
| e = base_expression { e }
| op = unop e = unop_expression { (Unop (op, e), Pos.from_lpos $sloc) }
mult_op:
| MULT { (Mult KInt, Pos.from_lpos $sloc) }
| DIV { (Div KInt, Pos.from_lpos $sloc) }
| MULTDEC { (Mult KDec, Pos.from_lpos $sloc) }
| DIVDEC { (Div KDec, Pos.from_lpos $sloc) }
| MULTMONEY { (Mult KMoney, Pos.from_lpos $sloc) }
| DIVMONEY { (Div KMoney, Pos.from_lpos $sloc) }
| MULDURATION { (Mult KDuration, Pos.from_lpos $sloc) }
| k = MULT { (Mult k, Pos.from_lpos $sloc) }
| k = DIV { (Div k, Pos.from_lpos $sloc) }
mult_expression:
| e = unop_expression { e }
@ -241,16 +217,8 @@ mult_expression:
}
sum_op:
| PLUSDURATION { (Add KDuration, Pos.from_lpos $sloc) }
| MINUSDURATION { (Sub KDuration, Pos.from_lpos $sloc) }
| PLUSDATE { (Add KDate, Pos.from_lpos $sloc) }
| MINUSDATE { (Sub KDate, Pos.from_lpos $sloc) }
| PLUSMONEY { (Add KMoney, Pos.from_lpos $sloc) }
| MINUSMONEY { (Sub KMoney, Pos.from_lpos $sloc) }
| PLUSDEC { (Add KDec, Pos.from_lpos $sloc) }
| MINUSDEC { (Sub KDec, Pos.from_lpos $sloc) }
| PLUS { (Add KInt, Pos.from_lpos $sloc) }
| MINUS { (Sub KInt, Pos.from_lpos $sloc) }
| k = PLUS { (Add k, Pos.from_lpos $sloc) }
| k = MINUS { (Sub k, Pos.from_lpos $sloc) }
| PLUSPLUS { (Concat, Pos.from_lpos $sloc) }
sum_expression:

View File

@ -40,18 +40,11 @@
%token COLON ALT DATA VERTICAL
%token OF INTEGER COLLECTION CONTAINS
%token RULE CONDITION DEFINED_AS
%token LESSER GREATER LESSER_EQUAL GREATER_EQUAL
%token LESSER_DEC GREATER_DEC LESSER_EQUAL_DEC GREATER_EQUAL_DEC
%token LESSER_MONEY GREATER_MONEY LESSER_EQUAL_MONEY GREATER_EQUAL_MONEY
%token LESSER_DATE GREATER_DATE LESSER_EQUAL_DATE GREATER_EQUAL_DATE
%token LESSER_DURATION GREATER_DURATION LESSER_EQUAL_DURATION GREATER_EQUAL_DURATION
%token<Ast.op_kind> LESSER GREATER LESSER_EQUAL GREATER_EQUAL
%token LET EXISTS IN SUCH THAT
%token DOT AND OR XOR LPAREN RPAREN EQUAL
%token CARDINAL ASSERTION FIXED BY YEAR MONTH DAY
%token PLUS MINUS MULT DIV
%token PLUSDEC MINUSDEC MULTDEC DIVDEC
%token PLUSMONEY MINUSMONEY MULTMONEY DIVMONEY
%token MINUSDATE PLUSDATE PLUSDURATION MINUSDURATION MULDURATION
%token<Ast.op_kind> PLUS MINUS MULT DIV
%token PLUSPLUS
%token MATCH WITH VARIES WITH_V WILDCARD
%token FOR ALL WE_HAVE INCREASING DECREASING

View File

@ -37,11 +37,28 @@ let conjunction (args : vc_return list) (mark : typed mark) : vc_return =
match args with hd :: tl -> hd, tl | [] -> (ELit (LBool true), mark), []
in
List.fold_left
(fun acc arg -> EApp { f = EOp (Binop And), mark; args = [arg; acc] }, mark)
(fun acc arg ->
( EApp
{
f =
( EOp
{
op = And;
tys = [TLit TBool, Expr.pos acc; TLit TBool, Expr.pos arg];
},
mark );
args = [arg; acc];
},
mark ))
acc list
let negation (arg : vc_return) (mark : typed mark) : vc_return =
EApp { f = EOp (Unop Not), mark; args = [arg] }, mark
( EApp
{
f = EOp { op = Not; tys = [TLit TBool, Expr.pos arg] }, mark;
args = [arg];
},
mark )
let disjunction (args : vc_return list) (mark : typed mark) : vc_return =
let acc, list =
@ -49,7 +66,18 @@ let disjunction (args : vc_return list) (mark : typed mark) : vc_return =
in
List.fold_left
(fun (acc : vc_return) arg ->
EApp { f = EOp (Binop Or), mark; args = [arg; acc] }, mark)
( EApp
{
f =
( EOp
{
op = Or;
tys = [TLit TBool, Expr.pos acc; TLit TBool, Expr.pos arg];
},
mark );
args = [arg; acc];
},
mark ))
acc list
(** [half_product \[a1,...,an\] \[b1,...,bm\] returns \[(a1,b1),...(a1,bn),...(an,b1),...(an,bm)\]] *)

View File

@ -111,7 +111,7 @@ let unique_name (v : 'e Var.t) : string =
let date_to_int (d : Runtime.date) : int =
(* Alternatively, could expose this from Runtime as a (noop) coercion, but
would allow to break abstraction more easily elsewhere *)
let period = Runtime.( -@ ) d base_day in
let period = Runtime.Oper.o_sub_dat_dat d base_day in
let y, m, d = Runtime.duration_to_years_months_days period in
assert (y = 0 && m = 0);
d
@ -124,7 +124,7 @@ let date_of_year (year : int) = Runtime.date_of_numbers year 1 1
defined here as Jan 1, 1900 **)
let nb_days_to_date (nb : int) : string =
Runtime.date_to_string
(Runtime.( +@ ) base_day (Runtime.duration_of_numbers 0 0 nb))
(Runtime.Oper.o_add_dat_dur base_day (Runtime.duration_of_numbers 0 0 nb))
(** [print_z3model_expr] pretty-prints the value [e] given by a Z3 model
according to the Catala type [ty], corresponding to [e] **)
@ -426,223 +426,181 @@ let is_leap_year = Runtime.is_leap_year
(** [translate_op] returns the Z3 expression corresponding to the application of
[op] to the arguments [args] **)
let rec translate_op (ctx : context) (op : dcalc operator) (args : 'm expr list)
: context * Expr.expr =
match op with
| Ternop _top ->
let _e1, _e2, _e3 =
match args with
| [e1; e2; e3] -> e1, e2, e3
| _ ->
Format.kasprintf failwith
"[Z3 encoding] Ill-formed ternary operator application: %a"
(Shared_ast.Expr.format ctx.ctx_decl)
(Shared_ast.Expr.eapp
(Shared_ast.Expr.eop op (Untyped { pos = Pos.no_pos }))
(List.map Shared_ast.Expr.untype args)
(Untyped { pos = Pos.no_pos })
|> Shared_ast.Expr.unbox)
in
let rec translate_op :
type k.
context -> (dcalc, k) operator -> 'm expr list -> context * Expr.expr =
fun ctx op args ->
let ill_formed () =
Format.kasprintf failwith
"[Z3 encoding] Ill-formed operator application: %a"
(Shared_ast.Expr.format ctx.ctx_decl)
(Shared_ast.Expr.eapp
(Shared_ast.Expr.eop op [] (Untyped { pos = Pos.no_pos }))
(List.map Shared_ast.Expr.untype args)
(Untyped { pos = Pos.no_pos })
|> Shared_ast.Expr.unbox)
in
let app f =
let ctx, args = List.fold_left_map translate_expr ctx args in
ctx, f ctx.ctx_z3 args
in
let app1 f =
app (fun ctx -> function [a] -> f ctx a | _ -> ill_formed ())
in
let app2 f =
app (fun ctx -> function [a; b] -> f ctx a b | _ -> ill_formed ())
in
match op, args with
| Fold, _ ->
failwith "[Z3 encoding] ternary operator application not supported"
| Binop bop -> (
(* Special case for GetYear comparisons *)
match bop, args with
| ( Lt KInt,
[
(EApp { f = EOp (Unop GetYear), _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n))
in
(* e2 corresponds to the first day of the year n. GetYear e1 < e2 can thus
be directly translated as < in the Z3 encoding using the number of
days *)
ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2
| ( Lte KInt,
[
(EApp { f = EOp (Unop GetYear), _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let ctx, e1 = translate_expr ctx e1 in
let nb_days = if is_leap_year n then 365 else 364 in
let n = Runtime.integer_to_int n in
(* We want that the year corresponding to e1 is smaller or equal to n. We
encode this as the day corresponding to e1 is smaller or equal than the
last day of the year [n], which is Jan 1st + 365 days if [n] is a leap
year, Jan 1st + 364 else *)
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n) + nb_days)
in
ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2
| ( Gt KInt,
[
(EApp { f = EOp (Unop GetYear), _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let ctx, e1 = translate_expr ctx e1 in
let nb_days = if is_leap_year n then 365 else 364 in
let n = Runtime.integer_to_int n in
(* We want that the year corresponding to e1 is greater to n. We encode
this as the day corresponding to e1 is greater than the last day of the
year [n], which is Jan 1st + 365 days if [n] is a leap year, Jan 1st +
364 else *)
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n) + nb_days)
in
ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2
| ( Gte KInt,
[
(EApp { f = EOp (Unop GetYear), _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n))
in
(* e2 corresponds to the first day of the year n. GetYear e1 >= e2 can
thus be directly translated as >= in the Z3 encoding using the number
of days *)
ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2
| ( Eq,
[
(EApp { f = EOp (Unop GetYear), _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in
let min_date =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n))
in
let max_date =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year (n + 1)))
in
( ctx,
Boolean.mk_and ctx.ctx_z3
[
Arithmetic.mk_ge ctx.ctx_z3 e1 min_date;
Arithmetic.mk_lt ctx.ctx_z3 e1 max_date;
] )
| _ -> (
let ctx, e1, e2 =
match args with
| [e1; e2] ->
let ctx, e1 = translate_expr ctx e1 in
let ctx, e2 = translate_expr ctx e2 in
ctx, e1, e2
| _ ->
Format.kasprintf failwith
"[Z3 encoding] Ill-formed binary operator application: %a"
(Shared_ast.Expr.format ctx.ctx_decl)
(Shared_ast.Expr.eapp
(Shared_ast.Expr.eop op (Untyped { pos = Pos.no_pos }))
(List.map Shared_ast.Expr.untype args)
(Untyped { pos = Pos.no_pos })
|> Shared_ast.Expr.unbox)
in
match bop with
| And -> ctx, Boolean.mk_and ctx.ctx_z3 [e1; e2]
| Or -> ctx, Boolean.mk_or ctx.ctx_z3 [e1; e2]
| Xor -> ctx, Boolean.mk_xor ctx.ctx_z3 e1 e2
| Add KInt | Add KRat | Add KMoney | Add KDate | Add KDuration ->
ctx, Arithmetic.mk_add ctx.ctx_z3 [e1; e2]
| Sub KInt | Sub KRat | Sub KMoney | Sub KDate | Sub KDuration ->
ctx, Arithmetic.mk_sub ctx.ctx_z3 [e1; e2]
| Mult KInt | Mult KRat | Mult KMoney | Mult KDate | Mult KDuration ->
ctx, Arithmetic.mk_mul ctx.ctx_z3 [e1; e2]
| Div KInt | Div KRat | Div KMoney ->
ctx, Arithmetic.mk_div ctx.ctx_z3 e1 e2
| Div _ ->
failwith
"[Z3 encoding] application of non-integer binary operator Div not \
supported"
| Lt KInt | Lt KRat | Lt KMoney | Lt KDate | Lt KDuration ->
ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2
| Lte KInt | Lte KRat | Lte KMoney | Lte KDate | Lte KDuration ->
ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2
| Gt KInt | Gt KRat | Gt KMoney | Gt KDate | Gt KDuration ->
ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2
| Gte KInt | Gte KRat | Gte KMoney | Gte KDate | Gte KDuration ->
ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2
| Eq -> ctx, Boolean.mk_eq ctx.ctx_z3 e1 e2
| Neq -> ctx, Boolean.mk_not ctx.ctx_z3 (Boolean.mk_eq ctx.ctx_z3 e1 e2)
| Map ->
failwith
"[Z3 encoding] application of binary operator Map not supported"
| Concat ->
failwith
"[Z3 encoding] application of binary operator Concat not supported"
| Filter ->
failwith
"[Z3 encoding] application of binary operator Filter not supported"))
| Unop uop -> (
let ctx, e1 =
match args with
| [e1] -> (
try translate_expr ctx e1
with Z3.Error s ->
Errors.raise_spanned_error (Shared_ast.Expr.pos e1) "%s" s)
| _ ->
Format.kasprintf failwith
"[Z3 encoding] Ill-formed unary operator application: %a"
(Shared_ast.Expr.format ctx.ctx_decl)
(Shared_ast.Expr.eapp
(Shared_ast.Expr.eop op (Untyped { pos = Pos.no_pos }))
(List.map Shared_ast.Expr.untype args)
(Untyped { pos = Pos.no_pos })
|> Shared_ast.Expr.unbox)
| ( Lt_int_int,
[
(EApp { f = EOp { op = GetYear; _ }, _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int (date_of_year n))
in
match uop with
| Not -> ctx, Boolean.mk_not ctx.ctx_z3 e1
| Minus _ ->
failwith "[Z3 encoding] application of unary operator Minus not supported"
(* Omitting the log from the VC *)
| Log _ -> ctx, e1
| Length ->
(* For now, an array is only its symbolic length. We simply return it *)
ctx, e1
| IntToRat ->
failwith
"[Z3 encoding] application of unary operator IntToRat not supported"
| MoneyToRat ->
failwith
"[Z3 encoding] application of unary operator MoneyToRat not supported"
| RatToMoney ->
failwith
"[Z3 encoding] application of unary operator RatToMoney not supported"
| GetDay ->
failwith
"[Z3 encoding] application of unary operator GetDay not supported"
| GetMonth ->
failwith
"[Z3 encoding] application of unary operator GetMonth not supported"
| GetYear ->
failwith
"[Z3 encoding] GetYear operator only supported in comparisons with \
literal"
| FirstDayOfMonth ->
failwith
"[Z3 encoding] FirstDayOfMonth operator only supported in comparisons \
with literal"
| LastDayOfMonth ->
failwith
"[Z3 encoding] LastDayOfMonth operator only supported in comparisons \
with literal"
| RoundDecimal ->
failwith "[Z3 encoding] RoundDecimal operator not implemented yet"
| RoundMoney ->
failwith "[Z3 encoding] RoundMoney operator not implemented yet")
(* e2 corresponds to the first day of the year n. GetYear e1 < e2 can thus
be directly translated as < in the Z3 encoding using the number of
days *)
ctx, Arithmetic.mk_lt ctx.ctx_z3 e1 e2
| ( Lte_int_int,
[
(EApp { f = EOp { op = GetYear; _ }, _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let ctx, e1 = translate_expr ctx e1 in
let nb_days = if is_leap_year n then 365 else 364 in
let n = Runtime.integer_to_int n in
(* We want that the year corresponding to e1 is smaller or equal to n. We
encode this as the day corresponding to e1 is smaller or equal than the
last day of the year [n], which is Jan 1st + 365 days if [n] is a leap
year, Jan 1st + 364 else *)
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n) + nb_days)
in
ctx, Arithmetic.mk_le ctx.ctx_z3 e1 e2
| ( Gt_int_int,
[
(EApp { f = EOp { op = GetYear; _ }, _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let ctx, e1 = translate_expr ctx e1 in
let nb_days = if is_leap_year n then 365 else 364 in
let n = Runtime.integer_to_int n in
(* We want that the year corresponding to e1 is greater to n. We encode this
as the day corresponding to e1 is greater than the last day of the year
[n], which is Jan 1st + 365 days if [n] is a leap year, Jan 1st + 364
else *)
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year n) + nb_days)
in
ctx, Arithmetic.mk_gt ctx.ctx_z3 e1 e2
| ( Gte_int_int,
[
(EApp { f = EOp { op = GetYear; _ }, _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in
let e2 =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int (date_of_year n))
in
(* e2 corresponds to the first day of the year n. GetYear e1 >= e2 can thus
be directly translated as >= in the Z3 encoding using the number of
days *)
ctx, Arithmetic.mk_ge ctx.ctx_z3 e1 e2
| ( Eq,
[
(EApp { f = EOp { op = GetYear; _ }, _; args = [e1] }, _);
(ELit (LInt n), _);
] ) ->
let n = Runtime.integer_to_int n in
let ctx, e1 = translate_expr ctx e1 in
let min_date =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3 (date_to_int (date_of_year n))
in
let max_date =
Arithmetic.Integer.mk_numeral_i ctx.ctx_z3
(date_to_int (date_of_year (n + 1)))
in
( ctx,
Boolean.mk_and ctx.ctx_z3
[
Arithmetic.mk_ge ctx.ctx_z3 e1 min_date;
Arithmetic.mk_lt ctx.ctx_z3 e1 max_date;
] )
| And, _ -> app Boolean.mk_and
| Or, _ -> app Boolean.mk_or
| Xor, _ -> app2 Boolean.mk_xor
| (Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur | Add_dur_dur), _ ->
app Arithmetic.mk_add
| ( ( Sub_int_int | Sub_rat_rat | Sub_mon_mon | Sub_dat_dat | Sub_dat_dur
| Sub_dur_dur ),
_ ) ->
app Arithmetic.mk_sub
| (Mult_int_int | Mult_rat_rat | Mult_mon_rat | Mult_dur_int), _ ->
app Arithmetic.mk_mul
| (Div_int_int | Div_rat_rat | Div_mon_rat | Div_mon_mon), _ ->
app2 Arithmetic.mk_div
| (Lt_int_int | Lt_rat_rat | Lt_mon_mon | Lt_dat_dat | Lt_dur_dur), _ ->
app2 Arithmetic.mk_lt
| (Lte_int_int | Lte_rat_rat | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur), _ ->
app2 Arithmetic.mk_le
| (Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur), _ ->
app2 Arithmetic.mk_gt
| (Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur), _ ->
app2 Arithmetic.mk_ge
| Eq, _ -> app2 Boolean.mk_eq
| Map, _ ->
failwith "[Z3 encoding] application of binary operator Map not supported"
| Concat, _ ->
failwith "[Z3 encoding] application of binary operator Concat not supported"
| Filter, _ ->
failwith "[Z3 encoding] application of binary operator Filter not supported"
| Not, _ -> app1 Boolean.mk_not
(* Omitting the log from the VC *)
| Log _, [e1] -> translate_expr ctx e1
| Length, [e1] ->
(* For now, an array is only its symbolic length. We simply return it *)
translate_expr ctx e1
| IntToRat, _ ->
failwith
"[Z3 encoding] application of unary operator IntToRat not supported"
| MoneyToRat, _ ->
failwith
"[Z3 encoding] application of unary operator MoneyToRat not supported"
| RatToMoney, _ ->
failwith
"[Z3 encoding] application of unary operator RatToMoney not supported"
| GetDay, _ ->
failwith "[Z3 encoding] application of unary operator GetDay not supported"
| GetMonth, _ ->
failwith
"[Z3 encoding] application of unary operator GetMonth not supported"
| GetYear, _ ->
failwith
"[Z3 encoding] GetYear operator only supported in comparisons with \
literal"
| FirstDayOfMonth, _ ->
failwith
"[Z3 encoding] FirstDayOfMonth operator only supported in comparisons \
with literal"
| LastDayOfMonth, _ ->
failwith
"[Z3 encoding] LastDayOfMonth operator only supported in comparisons \
with literal"
| RoundDecimal, _ ->
failwith "[Z3 encoding] RoundDecimal operator not implemented yet"
| RoundMoney, _ ->
failwith "[Z3 encoding] RoundMoney operator not implemented yet"
| _ -> ill_formed ()
(** [translate_expr] translate the expression [vc] to its corresponding Z3
expression **)
@ -780,7 +738,7 @@ and translate_expr (ctx : context) (vc : typed expr) : context * Expr.expr =
| EAbs _ -> failwith "[Z3 encoding] EAbs unsupported"
| EApp { f = head; args } -> (
match Marked.unmark head with
| EOp op -> translate_op ctx op args
| EOp { op; _ } -> translate_op ctx op args
| EVar v ->
let (Typed { ty = f_ty; _ }) = Marked.get_mark head in
let ctx, fd = find_or_create_funcdecl ctx v f_ty in

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -561,80 +561,8 @@ let handle_default_opt
let no_input : unit -> 'a = fun _ -> raise EmptyError
let ( *$ ) (i1 : money) (i2 : decimal) : money =
let i1_abs = Z.abs i1 in
let i2_abs = Q.abs i2 in
let sign_int = Z.sign i1 * Q.sign i2 in
let rat_result = Q.mul (Q.of_bigint i1_abs) i2_abs in
let res, remainder = Z.div_rem (Q.num rat_result) (Q.den rat_result) in
(* we perform nearest rounding when multiplying an amount of money by a
decimal !*)
if Z.(of_int 2 * remainder >= Q.den rat_result) then
Z.(add res (of_int 1) * of_int sign_int)
else Z.(res * of_int sign_int)
let ( /$ ) (m1 : money) (m2 : money) : decimal =
if Z.zero = m2 then raise Division_by_zero
else Q.div (Q.of_bigint m1) (Q.of_bigint m2)
let ( +$ ) (m1 : money) (m2 : money) : money = Z.add m1 m2
let ( -$ ) (m1 : money) (m2 : money) : money = Z.sub m1 m2
let ( ~-$ ) (m1 : money) : money = Z.sub Z.zero m1
let ( +! ) (i1 : integer) (i2 : integer) : integer = Z.add i1 i2
let ( -! ) (i1 : integer) (i2 : integer) : integer = Z.sub i1 i2
let ( ~-! ) (i1 : integer) : integer = Z.sub Z.zero i1
let ( *! ) (i1 : integer) (i2 : integer) : integer = Z.mul i1 i2
let ( /! ) (i1 : integer) (i2 : integer) : integer =
if Z.zero = i2 then raise Division_by_zero else Z.div i1 i2
let ( +& ) (i1 : decimal) (i2 : decimal) : decimal = Q.add i1 i2
let ( -& ) (i1 : decimal) (i2 : decimal) : decimal = Q.sub i1 i2
let ( ~-& ) (i1 : decimal) : decimal = Q.sub Q.zero i1
let ( *& ) (i1 : decimal) (i2 : decimal) : decimal = Q.mul i1 i2
let ( /& ) (i1 : decimal) (i2 : decimal) : decimal =
if Q.zero = i2 then raise Division_by_zero else Q.div i1 i2
let ( +@ ) : date -> duration -> date = Dates_calc.Dates.add_dates
let ( -@ ) : date -> date -> duration = Dates_calc.Dates.sub_dates
let ( +^ ) : duration -> duration -> duration = Dates_calc.Dates.add_periods
let ( -^ ) : duration -> duration -> duration = Dates_calc.Dates.sub_periods
let ( *^ ) (d : duration) (m : integer) : duration =
Dates_calc.Dates.mul_period d (Z.to_int m)
let ( <=$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 <= 0
let ( >=$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 >= 0
let ( <$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 < 0
let ( >$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 > 0
let ( =$ ) (m1 : money) (m2 : money) : bool = Z.compare m1 m2 = 0
let ( >=! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 >= 0
let ( <=! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 <= 0
let ( >! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 > 0
let ( <! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 < 0
let ( =! ) (i1 : integer) (i2 : integer) : bool = Z.compare i1 i2 = 0
let ( >=& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 >= 0
let ( <=& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 <= 0
let ( >& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 > 0
let ( <& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 < 0
let ( =& ) (i1 : decimal) (i2 : decimal) : bool = Q.compare i1 i2 = 0
let ( >=@ ) (d1 : date) (d2 : date) : bool =
Dates_calc.Dates.compare_dates d1 d2 >= 0
let ( <=@ ) (d1 : date) (d2 : date) : bool =
Dates_calc.Dates.compare_dates d1 d2 <= 0
let ( >@ ) (d1 : date) (d2 : date) : bool =
Dates_calc.Dates.compare_dates d1 d2 > 0
let ( <@ ) (d1 : date) (d2 : date) : bool =
Dates_calc.Dates.compare_dates d1 d2 < 0
let ( =@ ) (d1 : date) (d2 : date) : bool =
Dates_calc.Dates.compare_dates d1 d2 = 0
(* TODO: add a compare built-in to dates_calc. At the moment this fails on e.g.
[3 months, 4 months] *)
let compare_periods (p1 : duration) (p2 : duration) : int =
try
let p1_days = Dates_calc.Dates.period_to_days p1 in
@ -642,14 +570,102 @@ let compare_periods (p1 : duration) (p2 : duration) : int =
compare p1_days p2_days
with Dates_calc.Dates.AmbiguousComputation -> raise UncomparableDurations
let ( >=^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 >= 0
let ( <=^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 <= 0
let ( >^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 > 0
let ( <^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 < 0
let ( =^ ) (d1 : duration) (d2 : duration) : bool = compare_periods d1 d2 = 0
let ( ~-^ ) : duration -> duration = Dates_calc.Dates.neg_period
(* TODO: same here, although it was tweaked to never fail on equal dates.
Comparing the difference to duration_0 is not a good idea because we still
want to fail on [1 month, 30 days] rather than return [false] *)
let equal_periods (p1 : duration) (p2 : duration) : bool =
try Dates_calc.Dates.period_to_days (Dates_calc.Dates.sub_periods p1 p2) = 0
with Dates_calc.Dates.AmbiguousComputation -> raise UncomparableDurations
let array_filter (f : 'a -> bool) (a : 'a array) : 'a array =
Array.of_list (List.filter f (Array.to_list a))
module Oper = struct
let o_not = Stdlib.not
let o_length a = Z.of_int (Array.length a)
let o_intToRat = decimal_of_integer
let o_moneyToRat = decimal_of_money
let o_ratToMoney = money_of_decimal
let o_getDay = day_of_month_of_date
let o_getMonth = month_number_of_date
let o_getYear = year_of_date
let o_firstDayOfMonth = first_day_of_month
let o_lastDayOfMonth = last_day_of_month
let o_roundMoney = money_round
let o_roundDecimal = decimal_round
let o_minus_int i1 = Z.sub Z.zero i1
let o_minus_rat i1 = Q.sub Q.zero i1
let o_minus_mon m1 = Z.sub Z.zero m1
let o_minus_dur = Dates_calc.Dates.neg_period
let o_and = ( && )
let o_or = ( || )
let o_xor : bool -> bool -> bool = ( <> )
let o_eq = ( = )
let o_map = Array.map
let o_concat = Array.append
let o_filter f a = Array.of_list (List.filter f (Array.to_list a))
let o_add_int_int i1 i2 = Z.add i1 i2
let o_add_rat_rat i1 i2 = Q.add i1 i2
let o_add_mon_mon m1 m2 = Z.add m1 m2
let o_add_dat_dur da du = Dates_calc.Dates.add_dates da du
let o_add_dur_dur = Dates_calc.Dates.add_periods
let o_sub_int_int i1 i2 = Z.sub i1 i2
let o_sub_rat_rat i1 i2 = Q.sub i1 i2
let o_sub_mon_mon m1 m2 = Z.sub m1 m2
let o_sub_dat_dat = Dates_calc.Dates.sub_dates
let o_sub_dat_dur dat dur = Dates_calc.Dates.(add_dates dat (neg_period dur))
let o_sub_dur_dur = Dates_calc.Dates.sub_periods
let o_mult_int_int i1 i2 = Z.mul i1 i2
let o_mult_rat_rat i1 i2 = Q.mul i1 i2
let array_length (a : 'a array) : integer = Z.of_int (Array.length a)
let o_mult_mon_rat i1 i2 =
let i1_abs = Z.abs i1 in
let i2_abs = Q.abs i2 in
let sign_int = Z.sign i1 * Q.sign i2 in
let rat_result = Q.mul (Q.of_bigint i1_abs) i2_abs in
let res, remainder = Z.div_rem (Q.num rat_result) (Q.den rat_result) in
(* we perform nearest rounding when multiplying an amount of money by a
decimal !*)
if Z.(of_int 2 * remainder >= Q.den rat_result) then
Z.(add res (of_int 1) * of_int sign_int)
else Z.(res * of_int sign_int)
let o_mult_dur_int d m = Dates_calc.Dates.mul_period d (Z.to_int m)
let o_div_int_int i1 i2 = Z.div i1 i2 (* raises Division_by_zero *)
let o_div_rat_rat i1 i2 =
if Q.zero = i2 then raise Division_by_zero else Q.div i1 i2
let o_div_mon_mon m1 m2 =
if Z.zero = m2 then raise Division_by_zero
else Q.div (Q.of_bigint m1) (Q.of_bigint m2)
let o_div_mon_rat m1 r1 =
if Q.zero = r1 then raise Division_by_zero else o_mult_mon_rat m1 (Q.inv r1)
let o_lt_int_int i1 i2 = Z.compare i1 i2 < 0
let o_lt_rat_rat i1 i2 = Q.compare i1 i2 < 0
let o_lt_mon_mon m1 m2 = Z.compare m1 m2 < 0
let o_lt_dur_dur d1 d2 = compare_periods d1 d2 < 0
let o_lt_dat_dat d1 d2 = Dates_calc.Dates.compare_dates d1 d2 < 0
let o_lte_int_int i1 i2 = Z.compare i1 i2 <= 0
let o_lte_rat_rat i1 i2 = Q.compare i1 i2 <= 0
let o_lte_mon_mon m1 m2 = Z.compare m1 m2 <= 0
let o_lte_dur_dur d1 d2 = compare_periods d1 d2 <= 0
let o_lte_dat_dat d1 d2 = Dates_calc.Dates.compare_dates d1 d2 <= 0
let o_gt_int_int i1 i2 = Z.compare i1 i2 > 0
let o_gt_rat_rat i1 i2 = Q.compare i1 i2 > 0
let o_gt_mon_mon m1 m2 = Z.compare m1 m2 > 0
let o_gt_dur_dur d1 d2 = compare_periods d1 d2 > 0
let o_gt_dat_dat d1 d2 = Dates_calc.Dates.compare_dates d1 d2 > 0
let o_gte_int_int i1 i2 = Z.compare i1 i2 >= 0
let o_gte_rat_rat i1 i2 = Q.compare i1 i2 >= 0
let o_gte_mon_mon m1 m2 = Z.compare m1 m2 >= 0
let o_gte_dur_dur d1 d2 = compare_periods d1 d2 >= 0
let o_gte_dat_dat d1 d2 = Dates_calc.Dates.compare_dates d1 d2 >= 0
let o_eq_int_int i1 i2 = Z.equal i1 i2
let o_eq_rat_rat i1 i2 = Q.equal i1 i2
let o_eq_mon_mon m1 m2 = Z.equal m1 m2
let o_eq_dur_dur d1 d2 = equal_periods d1 d2
let o_eq_dat_dat d1 d2 = Dates_calc.Dates.compare_dates d1 d2 = 0
let o_fold = Array.fold_left
end
include Oper

View File

@ -285,85 +285,76 @@ val no_input : unit -> 'a
(**{1 Operators} *)
(**{2 Money} *)
module Oper : sig
(* The types **must** match with Shared_ast.Operator.*_type *)
val o_not : bool -> bool
val o_length : 'a array -> integer
val o_intToRat : integer -> decimal
val o_moneyToRat : money -> decimal
val o_ratToMoney : decimal -> money
val o_getDay : date -> integer
val o_getMonth : date -> integer
val o_getYear : date -> integer
val o_firstDayOfMonth : date -> date
val o_lastDayOfMonth : date -> date
val o_roundMoney : money -> money
val o_roundDecimal : decimal -> decimal
val o_minus_int : integer -> integer
val o_minus_rat : decimal -> decimal
val o_minus_mon : money -> money
val o_minus_dur : duration -> duration
val o_and : bool -> bool -> bool
val o_or : bool -> bool -> bool
val o_xor : bool -> bool -> bool
val o_eq : 'a -> 'a -> bool
val o_map : ('a -> 'b) -> 'a array -> 'b array
val o_concat : 'a array -> 'a array -> 'a array
val o_filter : ('a -> bool) -> 'a array -> 'a array
val o_add_int_int : integer -> integer -> integer
val o_add_rat_rat : decimal -> decimal -> decimal
val o_add_mon_mon : money -> money -> money
val o_add_dat_dur : date -> duration -> date
val o_add_dur_dur : duration -> duration -> duration
val o_sub_int_int : integer -> integer -> integer
val o_sub_rat_rat : decimal -> decimal -> decimal
val o_sub_mon_mon : money -> money -> money
val o_sub_dat_dat : date -> date -> duration
val o_sub_dat_dur : date -> duration -> date
val o_sub_dur_dur : duration -> duration -> duration
val o_mult_int_int : integer -> integer -> integer
val o_mult_rat_rat : decimal -> decimal -> decimal
val o_mult_mon_rat : money -> decimal -> money
val o_mult_dur_int : duration -> integer -> duration
val o_div_int_int : integer -> integer -> integer
val o_div_rat_rat : decimal -> decimal -> decimal
val o_div_mon_mon : money -> money -> decimal
val o_div_mon_rat : money -> decimal -> money
val o_lt_int_int : integer -> integer -> bool
val o_lt_rat_rat : decimal -> decimal -> bool
val o_lt_mon_mon : money -> money -> bool
val o_lt_dur_dur : duration -> duration -> bool
val o_lt_dat_dat : date -> date -> bool
val o_lte_int_int : integer -> integer -> bool
val o_lte_rat_rat : decimal -> decimal -> bool
val o_lte_mon_mon : money -> money -> bool
val o_lte_dur_dur : duration -> duration -> bool
val o_lte_dat_dat : date -> date -> bool
val o_gt_int_int : integer -> integer -> bool
val o_gt_rat_rat : decimal -> decimal -> bool
val o_gt_mon_mon : money -> money -> bool
val o_gt_dur_dur : duration -> duration -> bool
val o_gt_dat_dat : date -> date -> bool
val o_gte_int_int : integer -> integer -> bool
val o_gte_rat_rat : decimal -> decimal -> bool
val o_gte_mon_mon : money -> money -> bool
val o_gte_dur_dur : duration -> duration -> bool
val o_gte_dat_dat : date -> date -> bool
val o_eq_int_int : integer -> integer -> bool
val o_eq_rat_rat : decimal -> decimal -> bool
val o_eq_mon_mon : money -> money -> bool
val o_eq_dur_dur : duration -> duration -> bool
val o_eq_dat_dat : date -> date -> bool
val o_fold : ('a -> 'b -> 'a) -> 'a -> 'b array -> 'a
end
val ( *$ ) : money -> decimal -> money
val ( /$ ) : money -> money -> decimal
(** @raise Division_by_zero *)
val ( +$ ) : money -> money -> money
val ( -$ ) : money -> money -> money
val ( ~-$ ) : money -> money
val ( =$ ) : money -> money -> bool
val ( <=$ ) : money -> money -> bool
val ( >=$ ) : money -> money -> bool
val ( <$ ) : money -> money -> bool
val ( >$ ) : money -> money -> bool
(**{2 Integers} *)
val ( +! ) : integer -> integer -> integer
val ( -! ) : integer -> integer -> integer
val ( ~-! ) : integer -> integer
val ( *! ) : integer -> integer -> integer
val ( /! ) : integer -> integer -> integer
(** @raise Division_by_zero *)
val ( =! ) : integer -> integer -> bool
val ( >=! ) : integer -> integer -> bool
val ( <=! ) : integer -> integer -> bool
val ( >! ) : integer -> integer -> bool
val ( <! ) : integer -> integer -> bool
(** {2 Decimals} *)
val ( +& ) : decimal -> decimal -> decimal
val ( -& ) : decimal -> decimal -> decimal
val ( ~-& ) : decimal -> decimal
val ( *& ) : decimal -> decimal -> decimal
val ( /& ) : decimal -> decimal -> decimal
(** @raise Division_by_zero *)
val ( =& ) : decimal -> decimal -> bool
val ( >=& ) : decimal -> decimal -> bool
val ( <=& ) : decimal -> decimal -> bool
val ( >& ) : decimal -> decimal -> bool
val ( <& ) : decimal -> decimal -> bool
(** {2 Dates} *)
val ( +@ ) : date -> duration -> date
val ( -@ ) : date -> date -> duration
val ( =@ ) : date -> date -> bool
val ( >=@ ) : date -> date -> bool
val ( <=@ ) : date -> date -> bool
val ( >@ ) : date -> date -> bool
val ( <@ ) : date -> date -> bool
(** {2 Durations} *)
val ( +^ ) : duration -> duration -> duration
val ( -^ ) : duration -> duration -> duration
val ( *^ ) : duration -> integer -> duration
val ( ~-^ ) : duration -> duration
val ( =^ ) : duration -> duration -> bool
val ( >=^ ) : duration -> duration -> bool
(** @raise UncomparableDurations *)
val ( <=^ ) : duration -> duration -> bool
(** @raise UncomparableDurations *)
val ( >^ ) : duration -> duration -> bool
(** @raise UncomparableDurations *)
val ( <^ ) : duration -> duration -> bool
(** @raise UncomparableDurations *)
(** {2 Arrays} *)
val array_filter : ('a -> bool) -> 'a array -> 'a array
val array_length : 'a array -> integer
include module type of Oper

View File

@ -150,7 +150,12 @@ class Money:
return Money(Integer(res))
def __truediv__(self, other: 'Money') -> Decimal:
return Decimal(mpq(self.value.value / other.value.value))
if isinstance(other, Money):
return Decimal(mpq(self.value.value / other.value.value))
elif isinstance(other, Decimal):
return self * (1. / other.value)
else:
raise Exception("Dividing money and invalid obj")
def __neg__(self: 'Money') -> 'Money':
return Money(- self.value)
@ -193,8 +198,13 @@ class Date:
def __add__(self, other: 'Duration') -> 'Date':
return Date(self.value + other.value)
def __sub__(self, other: 'Date') -> 'Duration':
return Duration(dateutil.relativedelta.relativedelta(self.value, other.value))
def __sub__(self, other: object) -> object:
if isinstance(other, Date):
return Duration(dateutil.relativedelta.relativedelta(self.value, other.value))
elif isinstance(other, Duration):
return Date(self.value - other.value)
else:
raise Exception("Substracting date and invalid obj")
def __lt__(self, other: 'Date') -> bool:
return self.value < other.value

View File

@ -12,15 +12,13 @@ scope A:
```catala-test-inline
$ catala Interpret -s A
[ERROR] Error during typechecking, incompatible types:
--> integer
--> money
[ERROR] I don't know how to apply operator >= on types integer and
money
Error coming from typechecking the following expression:
┌─⯈ tests/test_array/bad/fold_error.catala_en:10.61-62:
┌─⯈ tests/test_array/bad/fold_error.catala_en:10.63-66:
└──┐
10 │ definition list_high_count equals number for m in list of (m >=$ $7)
│ ‾
‾‾
└─ Article
Type integer coming from expression:
@ -31,10 +29,10 @@ Type integer coming from expression:
└─ Article
Type money coming from expression:
┌─⯈ tests/test_array/bad/fold_error.catala_en:10.63-66:
┌─⯈ tests/test_array/bad/fold_error.catala_en:10.67-69:
└──┐
10 │ definition list_high_count equals number for m in list of (m >=$ $7)
‾‾
‾‾
└─ Article
#return code 255#
```

View File

@ -15,17 +15,17 @@ $ catala Typecheck
--> bool
Error coming from typechecking the following expression:
┌─⯈ tests/test_bool/bad/test_xor_with_int.catala_en:8.36-38:
┌─⯈ tests/test_bool/bad/test_xor_with_int.catala_en:8.29-31:
└─┐
8 │ definition test_var equals 10 xor 20
‾‾
│ ‾‾
└─ 'xor' should be a boolean operator
Type integer coming from expression:
┌─⯈ tests/test_bool/bad/test_xor_with_int.catala_en:8.36-38:
┌─⯈ tests/test_bool/bad/test_xor_with_int.catala_en:8.29-31:
└─┐
8 │ definition test_var equals 10 xor 20
‾‾
│ ‾‾
└─ 'xor' should be a boolean operator
Type bool coming from expression:

View File

@ -24,7 +24,7 @@ let TestBool :
in
let foo1 : bool = error_empty
⟨foo () | true ⊢
⟨⟨bar1 >= 0 ⊢ true⟩, ⟨bar1 < 0 ⊢ false⟩ | false ⊢
⟨⟨bar1 >=! 0 ⊢ true⟩, ⟨bar1 <! 0 ⊢ false⟩ | false ⊢
∅ ⟩⟩ in
TestBool { "foo"= foo1; "bar"= bar1 } in
TestBool
@ -47,5 +47,5 @@ struct TestBool = {
let scope TestBool (foo: bool|context|output) (bar: integer|context|output) =
let bar : integer = reentrant or by default ⟨true ⊢ 1⟩;
let foo : bool = reentrant or by default
⟨⟨bar >= 0 ⊢ true⟩, ⟨bar < 0 ⊢ false⟩ | false ⊢ ∅ ⟩
⟨⟨bar >=! 0 ⊢ true⟩, ⟨bar <! 0 ⊢ false⟩ | false ⊢ ∅ ⟩
```

View File

@ -19,7 +19,7 @@ scope B:
$ catala Scopelang -s B
let scope B (b: bool|input) =
let a.f : integer → integer =
λ (param: integer) → ⟨b && param > 0 ⊢ param - 1⟩;
λ (param: integer) → ⟨b && param >! 0 ⊢ param -! 1⟩;
call A[a]
```
@ -30,7 +30,7 @@ let A =
let f : integer → integer = A_in."f_in" in
let f1 : integer → integer =
λ (param: integer) → error_empty
⟨f param | true ⊢ ⟨true ⊢ param + 1⟩⟩ in
⟨f param | true ⊢ ⟨true ⊢ param +! 1⟩⟩ in
A { }
```
@ -40,7 +40,7 @@ let B =
λ (B_in: B_in {"b_in": bool}) →
let b : bool = B_in."b_in" in
let a.f : integer → integer =
λ (param: integer) → ⟨b && param > 0 ⊢ param - 1⟩ in
λ (param: integer) → ⟨b && param >! 0 ⊢ param -! 1⟩ in
let result : A {} = A (A_in { "f_in"= a.f }) in
B { }
```

View File

@ -27,11 +27,11 @@ let A =
let e : unit → integer = A_in."e_in" in
let f : unit → integer = A_in."f_in" in
let a : integer = error_empty ⟨true ⊢ 0⟩ in
let b : integer = error_empty ⟨true ⊢ a + 1⟩ in
let b : integer = error_empty ⟨true ⊢ a +! 1⟩ in
let e1 : integer = error_empty
⟨e () | true ⊢ ⟨true ⊢ b + c + d + 1⟩⟩ in
⟨e () | true ⊢ ⟨true ⊢ b +! c +! d +! 1⟩⟩ in
let f1 : integer = error_empty
⟨f () | true ⊢ ⟨true ⊢ e1 + 1⟩⟩ in
⟨f () | true ⊢ ⟨true ⊢ e1 +! 1⟩⟩ in
A { "b"= b; "d"= d; "f"= f1 }
```

View File

@ -31,10 +31,10 @@ Type decimal coming from expression:
Type collection coming from expression:
┌─⯈ tests/test_typing/bad/err2.catala_en:10.22-28:
┌─⯈ tests/test_typing/bad/err2.catala_en:10.35-37:
└──┐
10 │ definition a equals number of (z ++ 1.1) / 2
‾‾‾‾‾‾
‾‾
#return code 255#
```

View File

@ -0,0 +1,68 @@
```catala
declaration scope S:
internal i1 content integer
internal i2 content integer
internal x1 content decimal
internal x2 content decimal
internal m1 content money
internal m2 content money
internal d1 content duration
internal d2 content duration
internal t1 content date
internal t2 content date
output o_i content integer
output o_x content decimal
output o_m content money
output o_d content duration
output o_t content date
output o_b content boolean
scope S:
definition i1 equals 1
definition i2 equals 2
definition x1 equals 3.
definition x2 equals 4.
definition m1 equals $5
definition m2 equals $6
definition d1 equals 7 day
definition d2 equals 8 day
definition t1 equals |2022-01-09|
definition t2 equals |2022-01-10|
definition o_i equals -i1 + i2 - i1 * i2 / (i1 + i2)
definition o_x equals -x1 + x2 - x1 * x2 / (x1 + x2)
definition o_m equals -m1 + m2 - m1 * x2 / (x1 * m1 / m2) + m1 / x2
definition o_d equals -d1 + d2 - d1 * i2
definition o_t equals d1 + t1 + d1 + (t2 - t1)
definition o_b equals
i1 < i2 and x1 < x2 and m1 < m2 and d1 < d2 and t1 < t2 and
i1 <= i2 and x1 <= x2 and m1 <= m2 and d1 <= d2 and t1 <= t2 and
not (
i1 > i2 or x1 > x2 or m1 > m2 or d1 > d2 or t1 > t2 or
i1 >= i2 or x1 >= x2 or m1 >= m2 or d1 >= d2 or t1 >= t2
)
assertion o_i = -i1 +! i2 -! i1 *! i2 /! (i1 +! i2)
assertion o_x = -.x1 +. x2 -. x1 *. x2 /. (x1 +. x2)
assertion o_m = -$m1 +$ m2 -$ m1 *$ x2 / (m1 *$ x1 /$ m2) +$ m1 / x2
assertion o_d = -^d1 +^ d2 -^ d1 *^ i2
assertion o_t = t1 +@ d1 +@ (t2 -@ t1) +@ d1
assertion o_b =
i1 <! i2 and x1 <. x2 and m1 <$ m2 and d1 <^ d2 and t1 <@ t2 and
i1 <=! i2 and x1 <=. x2 and m1 <=$ m2 and d1 <=^ d2 and t1 <=@ t2 and
not (
i1 >! i2 or x1 >. x2 or m1 >$ m2 or d1 >^ d2 or t1 >@ t2 or
i1 >=! i2 or x1 >=. x2 or m1 >=$ m2 or d1 >=^ d2 or t1 >=@ t2
)
```
```catala-test-inline
$ catala Interpret -s S
[RESULT] Computation successful! Results:
[RESULT] o_b = true
[RESULT] o_d = [0 years, 0 months, -13 days]
[RESULT] o_i = 1
[RESULT] o_m = $-5.75
[RESULT] o_t = 2022-01-24
[RESULT] o_x = -0.71428571428571428571…
```

View File

@ -35,8 +35,8 @@ $ catala Scopelang -s A
let scope A (foo_bar: integer|context) (foo_baz: integer|internal)
(foo_fizz: integer|internal|output) =
let foo_bar : integer = reentrant or by default ⟨true ⊢ 1⟩;
let foo_baz : integer = ⟨true ⊢ foo_bar + 1⟩;
let foo_fizz : integer = ⟨true ⊢ foo_baz + 1⟩
let foo_baz : integer = ⟨true ⊢ foo_bar +! 1⟩;
let foo_fizz : integer = ⟨true ⊢ foo_baz +! 1⟩
```
```catala-test-inline