Translation now typechecks

This commit is contained in:
Denis Merigoux 2021-06-24 22:55:27 +02:00
parent 95b34937a6
commit 3f5027e5a5
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
5 changed files with 2033 additions and 1723 deletions

View File

@ -228,6 +228,8 @@ build_french_law_library_js: generate_french_law_library_ocaml format
#> generate_french_law_library_python : Generates the French law library Python sources from Catala
generate_french_law_library_python:\
$(FRENCH_LAW_PYTHON_LIB_DIR)/allocations_familiales.py
. $(FRENCH_LAW_PYTHON_LIB_DIR)/env/bin/activate ;\
$(MAKE) -C $(FRENCH_LAW_PYTHON_LIB_DIR) format
#> type_french_law_library_python : Types the French law library Python sources with mypy
type_french_law_library_python: generate_french_law_library_python

View File

@ -6,6 +6,9 @@ dependencies:
type:
mypy $(SOURCES)
format:
autopep8 --in-place $(SOURCES)
doc:
mkdir -p doc
sphinx-build ./ doc

File diff suppressed because it is too large Load Diff

View File

@ -11,26 +11,247 @@
from gmpy2 import log2, mpz, mpq, mpfr, mpc # type: ignore
import datetime
import dateutil.relativedelta # type: ignore
from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable
import dateutil.relativedelta
from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable, Union
from functools import reduce
Alpha = TypeVar('Alpha')
Beta = TypeVar('Beta')
# =====
# Types
# =====
# ============
# Type classes
# ============
Integer = NewType('Integer', object)
Decimal = NewType('Decimal', object)
Money = NewType('Money', Integer)
Date = NewType('Date', datetime.date)
Duration = NewType('Duration', object)
class Integer:
def __init__(self, value: Union[str, int]) -> None:
self.value = mpz(value)
def __add__(self, other: 'Integer') -> 'Integer':
return Integer(self.value + other.value)
def __sub__(self, other: 'Integer') -> 'Integer':
return Integer(self.value - other.value)
def __mul__(self, other: 'Integer') -> 'Integer':
return Integer(self.value * other.value)
def __truediv__(self, other: 'Integer') -> 'Integer':
return Integer(self.value // other.value)
def __neg__(self: 'Integer') -> 'Integer':
return Integer(- self.value)
def __lt__(self, other: 'Integer') -> bool:
return self.value < other.value
def __le__(self, other: 'Integer') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Integer') -> bool:
return self.value > other.value
def __ge__(self, other: 'Integer') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Integer):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Integer):
return self.value == other.value
else:
return False
class Decimal:
def __init__(self, value: Union[str, int, float, Integer]) -> None:
self.value = mpq(value)
def __add__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value + other.value)
def __sub__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value - other.value)
def __mul__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value * other.value)
def __truediv__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value / other.value)
def __neg__(self: 'Decimal') -> 'Decimal':
return Decimal(- self.value)
def __lt__(self, other: 'Decimal') -> bool:
return self.value < other.value
def __le__(self, other: 'Decimal') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Decimal') -> bool:
return self.value > other.value
def __ge__(self, other: 'Decimal') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Decimal):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Decimal):
return self.value == other.value
else:
return False
class Money:
def __init__(self, value: Integer) -> None:
self.value = value
def __add__(self, other: 'Money') -> 'Money':
return Money(self.value + other.value)
def __sub__(self, other: 'Money') -> 'Money':
return Money(self.value - other.value)
def __mul__(self, other: Decimal) -> 'Money':
return Money(Integer(self.value.value * other.value))
def __truediv__(self, other: 'Money') -> Decimal:
return Decimal(mpq(self.value.value / other.value.value))
def __neg__(self: 'Money') -> 'Money':
return Money(- self.value)
def __lt__(self, other: 'Money') -> bool:
return self.value < other.value
def __le__(self, other: 'Money') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Money') -> bool:
return self.value > other.value
def __ge__(self, other: 'Money') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Money):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Money):
return self.value == other.value
else:
return False
class Date:
def __init__(self, value: datetime.date) -> None:
self.value = value
def __add__(self, other: 'Duration') -> 'Date':
return Date(self.value + other.value)
def __sub__(self, other: 'Date') -> 'Duration':
# Careful: invert argument order
return Duration(dateutil.relativedelta.relativedelta(other.value, self.value))
def __lt__(self, other: 'Date') -> bool:
return self.value < other.value
def __le__(self, other: 'Date') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Date') -> bool:
return self.value > other.value
def __ge__(self, other: 'Date') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Date):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Date):
return self.value == other.value
else:
return False
class Duration:
def __init__(self, value: dateutil.relativedelta.relativedelta) -> None:
self.value = value
def __add__(self, other: 'Duration') -> 'Duration':
return Duration(self.value + other.value)
def __sub__(self, other: 'Duration') -> 'Duration':
return Duration(self.value - other.value)
def __neg__(self: 'Duration') -> 'Duration':
return Duration(- self.value)
def __lt__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days < y.days
def __le__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days <= y.days
def __gt__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days > y.days
def __ge__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days >= y.days
def __ne__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.value == other.value
else:
return False
class Unit:
pass
def __init__(self) -> None:
...
class SourcePosition:
@ -69,58 +290,6 @@ class NoValueProvided(Exception):
def __init__(self, source_position: SourcePosition) -> None:
self.source_position = SourcePosition
def raise_(ex):
raise ex
class TryCatch:
def __init__(self, fun, *args, **kwargs):
self.fun = fun
self.args = args
self.kwargs = kwargs
self.exception_types_and_handlers = []
self.finalize = None
def rescue(self, exception_types, handler):
if not isinstance(exception_types, Iterable):
exception_types = (exception_types,)
self.exception_types_and_handlers.append((exception_types, handler))
return self
def ensure(self, finalize, *finalize_args, **finalize_kwargs):
if self.finalize is not None:
raise Exception('ensure() called twice')
self.finalize = finalize
self.finalize_args = finalize_args
self.finalize_kwargs = finalize_kwargs
return self
def __call__(self):
try:
return self.fun(*self.args, **self.kwargs)
except BaseException as exc:
handler = self.find_applicable_handler(exc)
if handler is None:
raise
return handler(exc)
finally:
if self.finalize is not None:
self.finalize()
def find_applicable_handler(self, exc):
applicable_handlers = (
handler
for exception_types, handler in self.exception_types_and_handlers
if isinstance(exc, exception_types)
)
return next(applicable_handlers, None)
# ============================
# Constructors and conversions
# ============================
@ -131,19 +300,19 @@ class TryCatch:
def money_of_cents_string(v: str) -> Money:
return Money(mpz(v))
return Money(Integer(v))
def money_of_cents_int(v: int) -> Money:
return Money(mpz(v))
return Money(Integer(v))
def money_of_cents_integer(v: Integer) -> Money:
return Money(mpz(v))
return Money(v)
def money_to_float(m: Money) -> float:
return float(mpfr(mpq(m, 100)))
return float(mpfr(mpq(m.value.value, 100)))
def money_to_string(m: Money) -> str:
@ -151,7 +320,7 @@ def money_to_string(m: Money) -> str:
def money_to_cents(m: Money) -> Integer:
return m
return m.value
# --------
# Decimals
@ -159,23 +328,23 @@ def money_to_cents(m: Money) -> Integer:
def decimal_of_string(d: str) -> Decimal:
return Decimal(mpq(d))
return Decimal(d)
def decimal_to_float(d: Decimal) -> float:
return float(mpfr(d))
return float(mpfr(d.value))
def decimal_of_float(d: float) -> Decimal:
return Decimal(mpq(d))
return Decimal(d)
def decimal_of_integer(d: Integer) -> Decimal:
return Decimal(mpq(d))
return Decimal(d.value)
def decimal_to_string(precision: int, i: Decimal) -> str:
return "{1:.{0}}".format(precision, mpfr(i, precision * 10 // 2))
return "{1:.{0}}".format(precision, mpfr(i.value, precision * 10 // 2))
# --------
# Integers
@ -183,19 +352,19 @@ def decimal_to_string(precision: int, i: Decimal) -> str:
def integer_of_string(s: str) -> Integer:
return Integer(mpz(s))
return Integer(s)
def integer_to_string(d: Integer) -> str:
return str(d)
return str(d.value)
def integer_of_int(d: int) -> Integer:
return Integer(mpz(d))
return Integer(d)
def integer_to_int(d: Integer) -> int:
return int(d) # type: ignore
return int(d.value)
def integer_exponentiation(i: Integer, e: int) -> Integer:
@ -203,7 +372,7 @@ def integer_exponentiation(i: Integer, e: int) -> Integer:
def integer_log2(i: Integer) -> int:
return int(log2(i))
return int(log2(i.value))
# -----
# Dates
@ -211,19 +380,19 @@ def integer_log2(i: Integer) -> int:
def day_of_month_of_date(d: Date) -> Integer:
return integer_of_int(d.day)
return integer_of_int(d.value.day)
def month_number_of_date(d: Date) -> Integer:
return integer_of_int(d.month)
return integer_of_int(d.value.month)
def year_of_date(d: Date) -> Integer:
return integer_of_int(d.year)
return integer_of_int(d.value.year)
def date_to_string(d: Date) -> str:
return "{}".format(d)
return "{}".format(d.value)
def date_of_numbers(year: int, month: int, day: int) -> Date:
@ -239,11 +408,11 @@ def duration_of_numbers(years: int, months: int, days: int) -> Duration:
def duration_to_years_months_days(d: Duration) -> Tuple[int, int, int]:
return (d.years, d.months, d.days) # type: ignore
return (d.value.years, d.value.months, d.value.days) # type: ignore
def duration_to_string(s: Duration) -> str:
return "{}".format(s)
return "{}".format(s.value)
# -----
# Lists
@ -297,9 +466,10 @@ def handle_default(
return acc
def no_input() -> Callable[[], Alpha]:
# From https://stackoverflow.com/questions/8294618/define-a-lambda-expression-that-raises-an-exception
return (_ for _ in ()).throw(EmptyError)
def no_input() -> Callable[[Unit], Alpha]:
def closure(_: Unit()):
raise EmptyError
return closure
# =======
# Logging

View File

@ -2,4 +2,6 @@ gmpy2
typing
mypy
python-dateutil
types-python-dateutil
sphinx
autopep8