Thunking justifications and conclusion in avoid_translation pass

This commit is contained in:
adelaett 2023-04-11 17:46:41 +02:00
parent 83553d5950
commit cc66023e51
No known key found for this signature in database
GPG Key ID: 367A8C08F513BD65
9 changed files with 746 additions and 22 deletions

View File

@ -253,7 +253,7 @@ let rec trans ctx (e : 'm D.expr) : (lcalc, 'm mark) boxed_gexpr =
let m' = match m with Typed m -> Typed { m with ty = TAny, pos } in
Expr.make_app
(Expr.eop Op.HandleDefaultOpt [TAny, pos; TAny, pos; TAny, pos] m')
[Expr.earray excepts' m; just'; cons']
[Expr.earray excepts' m; Expr.thunk_term just' m; Expr.thunk_term cons' m]
pos
| ELit l -> monad_return ~mark (Expr.elit l m)
| EEmptyError -> monad_empty ~mark

View File

@ -753,10 +753,18 @@ let make_app ?(decl_ctx = None) e args pos =
in
eapp e args mark
let empty_thunked_term mark =
let thunk_term term mark =
let silent = Var.make "_" in
let pos = mark_pos mark in
make_abs [| silent |] (Bindlib.box EEmptyError, mark) [TLit TUnit, pos] pos
make_abs [| silent |] term [TLit TUnit, pos] pos
let empty_thunked_term mark = thunk_term (Bindlib.box EEmptyError, mark) mark
(* let unthunk_term term mark = let pos = mark_pos mark in make_app term [elit
LUnit mark] pos *)
let unthunk_term_nobox term mark =
Marked.mark mark (EApp { f = term; args = [ELit LUnit, mark] })
let make_let_in x tau e1 e2 mpos =
make_app (make_abs [| x |] e2 [tau] mpos) [e1] (pos e2)

View File

@ -274,7 +274,6 @@ val make_abs :
('a any, 'm mark) boxed_gexpr
val make_app :
?decl_ctx:decl_ctx option ->
('a any, 'm mark) boxed_gexpr ->
('a, 'm mark) boxed_gexpr list ->
@ -284,6 +283,17 @@ val make_app :
val empty_thunked_term :
'm mark -> ([< all > `DefaultTerms ], 'm mark) boxed_gexpr
val thunk_term :
(([< all ] as 'a), 'b mark) boxed_gexpr ->
'b mark ->
('a, 'b mark) boxed_gexpr
(* val unthunk_term : (([< all ] as 'a), Pos.t) boxed_gexpr -> 'b mark -> ('a,
Pos.t) boxed_gexpr *)
val unthunk_term_nobox :
(([< all ] as 'a), 'm mark) gexpr -> 'm mark -> ('a, 'm mark) gexpr
val make_let_in :
('a, 'm mark) gexpr Var.t ->
typ ->

View File

@ -340,8 +340,7 @@ let rec evaluate_operator
ELit (LBool (o_eq_dat_dat x y))
| Eq_dur_dur, [(ELit (LDuration x), _); (ELit (LDuration y), _)] ->
ELit (LBool (protect o_eq_dur_dur x y))
| HandleDefaultOpt, [(EArray exps, _); (juststification, _); (conclusion, _)]
-> (
| HandleDefaultOpt, [(EArray exps, _); justification; conclusion] -> (
let valid_exceptions =
ListLabels.filter exps ~f:(function
| EInj { name; cons; _ }, _
@ -357,11 +356,14 @@ let rec evaluate_operator
match valid_exceptions with
| [] -> (
match juststification with
match
Marked.unmark
(evaluate_expr ctx (Expr.unthunk_term_nobox justification m))
with
| EInj { name; cons; e = ELit (LBool true), _ }
when EnumName.equal name Definitions.option_enum
&& EnumConstructor.equal cons Definitions.some_constr ->
conclusion
Marked.unmark (evaluate_expr ctx (Expr.unthunk_term_nobox conclusion m))
| EInj { name; cons; e = (ELit (LBool false), _) as e }
when EnumName.equal name Definitions.option_enum
&& EnumConstructor.equal cons Definitions.some_constr ->
@ -454,7 +456,7 @@ let rec evaluate_expr :
| Lte_dur_dur | Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat
| Gt_dur_dur | Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat
| Gte_dur_dur | Eq_int_int | Eq_rat_rat | Eq_mon_mon | Eq_dat_dat
| Eq_dur_dur ) as op;
| Eq_dur_dur | HandleDefault | HandleDefaultOpt ) as op;
_;
} ->
evaluate_operator evaluate_expr ctx op m args

View File

@ -256,7 +256,8 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Marked.pos) :
| Length -> [array any] @-> it
| HandleDefault -> [array ([ut] @-> any); [ut] @-> bt; [ut] @-> any] @-> any
| HandleDefaultOpt ->
[array (option any); option bt; option any] @-> option any
[array (option any); [ut] @-> option bt; [ut] @-> option any]
@-> option any
in
Lazy.force ty

View File

@ -0,0 +1,703 @@
"""
.. module:: catala_runtime
:platform: Unix, Windows
:synopsis: The Python bindings for the functions used in the generated Catala code
:noindex:
.. moduleauthor:: Denis Merigoux <denis.merigoux@inria.fr>
"""
# This file should be in sync with compiler/runtime.{ml, mli} !
from gmpy2 import log2, mpz, mpq, mpfr, t_divmod, qdiv, f_div, sign # type: ignore
import datetime
import calendar
import dateutil.relativedelta
from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable, Union, Any
from functools import reduce
from enum import Enum
import copy
Alpha = TypeVar('Alpha')
Beta = TypeVar('Beta')
# ============
# Type classes
# ============
class Integer:
def __init__(self, value: Union[str, int]) -> None:
self.value = mpz(value)
def __add__(self, other: 'Integer') -> 'Integer':
return Integer(self.value + other.value)
def __sub__(self, other: 'Integer') -> 'Integer':
return Integer(self.value - other.value)
def __mul__(self, other: 'Integer') -> 'Integer':
return Integer(self.value * other.value)
def __truediv__(self, other: 'Integer') -> 'Decimal':
return Decimal (self.value) / Decimal (other.value)
def __neg__(self: 'Integer') -> 'Integer':
return Integer(- self.value)
def __lt__(self, other: 'Integer') -> bool:
return self.value < other.value
def __le__(self, other: 'Integer') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Integer') -> bool:
return self.value > other.value
def __ge__(self, other: 'Integer') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Integer):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Integer):
return self.value == other.value
else:
return False
def __str__(self) -> str:
return self.value.__str__()
def __repr__(self) -> str:
return f"Integer({self.value.__repr__()})"
class Decimal:
def __init__(self, value: Union[str, int, float]) -> None:
self.value = mpq(value)
def __add__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value + other.value)
def __sub__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value - other.value)
def __mul__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value * other.value)
def __truediv__(self, other: 'Decimal') -> 'Decimal':
return Decimal(self.value / other.value)
def __neg__(self: 'Decimal') -> 'Decimal':
return Decimal(- self.value)
def __lt__(self, other: 'Decimal') -> bool:
return self.value < other.value
def __le__(self, other: 'Decimal') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Decimal') -> bool:
return self.value > other.value
def __ge__(self, other: 'Decimal') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Decimal):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Decimal):
return self.value == other.value
else:
return False
def __str__(self) -> str:
return "{}".format(mpfr(self.value))
def __repr__(self) -> str:
return f"Decimal({self.value.__repr__()})"
class Money:
def __init__(self, value: Integer) -> None:
self.value = value
def __add__(self, other: 'Money') -> 'Money':
return Money(self.value + other.value)
def __sub__(self, other: 'Money') -> 'Money':
return Money(self.value - other.value)
def __mul__(self, other: Decimal) -> 'Money':
cents = self.value.value
coeff = other.value
# TODO: change, does not work with negative values. Must divide the
# absolute values and then multiply by the resulting sign.
rat_result = self.value.value * other.value
out = Money(Integer(rat_result))
res, remainder = t_divmod(rat_result.numerator, rat_result.denominator)
if 2 * remainder >= rat_result.denominator:
return Money(Integer(res + 1))
else:
return Money(Integer(res))
def __truediv__(self, other: 'Money') -> Decimal:
if isinstance(other, Money):
return self.value / other.value
elif isinstance(other, Decimal):
return self * (1. / other.value)
else:
raise Exception("Dividing money and invalid obj")
def __neg__(self: 'Money') -> 'Money':
return Money(- self.value)
def __lt__(self, other: 'Money') -> bool:
return self.value < other.value
def __le__(self, other: 'Money') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Money') -> bool:
return self.value > other.value
def __ge__(self, other: 'Money') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Money):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Money):
return self.value == other.value
else:
return False
def __str__(self) -> str:
return "${:.2}".format(self.value.value / 100)
def __repr__(self) -> str:
return f"Money({self.value.__repr__()})"
class Date:
def __init__(self, value: datetime.date) -> None:
self.value = value
def __add__(self, other: 'Duration') -> 'Date':
return Date(self.value + other.value)
def __sub__(self, other: object) -> object:
if isinstance(other, Date):
return Duration(dateutil.relativedelta.relativedelta(days=(self.value - other.value).days))
elif isinstance(other, Duration):
return Date(self.value - other.value)
else:
raise Exception("Substracting date and invalid obj")
def __lt__(self, other: 'Date') -> bool:
return self.value < other.value
def __le__(self, other: 'Date') -> bool:
return self.value <= other.value
def __gt__(self, other: 'Date') -> bool:
return self.value > other.value
def __ge__(self, other: 'Date') -> bool:
return self.value >= other.value
def __ne__(self, other: object) -> bool:
if isinstance(other, Date):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Date):
return self.value == other.value
else:
return False
def __str__(self) -> str:
return self.value.__str__()
def __repr__(self) -> str:
return f"Date({self.value.__repr__()})"
class Duration:
def __init__(self, value: dateutil.relativedelta.relativedelta) -> None:
self.value = value
def __add__(self, other: 'Duration') -> 'Duration':
return Duration(self.value + other.value)
def __sub__(self, other: 'Duration') -> 'Duration':
return Duration(self.value - other.value)
def __neg__(self: 'Duration') -> 'Duration':
return Duration(- self.value)
def __truediv__(self, other: 'Duration') -> Decimal:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only divide durations expressed in days")
else:
return Decimal(x.days / y.days)
def __mul__(self: 'Duration', rhs: Integer) -> 'Duration':
return Duration(
dateutil.relativedelta.relativedelta(years=self.value.years * rhs.value,
months=self.value.months * rhs.value,
days=self.value.days * rhs.value))
def __lt__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days < y.days
def __le__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days <= y.days
def __gt__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days > y.days
def __ge__(self, other: 'Duration') -> bool:
x = self.value.normalized()
y = other.value.normalized()
if (x.years != 0 or y.years != 0 or x.months != 0 or y.months != 0):
raise Exception("Can only compare durations expressed in days")
else:
return x.days >= y.days
def __ne__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.value != other.value
else:
return True
def __eq__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.value == other.value
else:
return False
def __str__(self) -> str:
return self.value.__str__()
def __repr__(self) -> str:
return f"Duration({self.value.__repr__()})"
class Unit:
def __init__(self) -> None:
...
def __eq__(self, other: object) -> bool:
if isinstance(other, Unit):
return True
else:
return False
def __ne__(self, other: object) -> bool:
if isinstance(other, Unit):
return False
else:
return True
def __str__(self) -> str:
return "()"
def __repr__(self) -> str:
return "Unit()"
class SourcePosition:
def __init__(self,
filename: str,
start_line: int,
start_column: int,
end_line: int,
end_column: int,
law_headings: List[str]) -> None:
self.filename = filename
self.start_line = start_line
self.start_column = start_column
self.end_line = end_line
self.end_column = end_column
self.law_headings = law_headings
def __str__(self) -> str:
return "in file {}, from {}:{} to {}:{}".format(
self.filename, self.start_line, self.start_column, self.end_line, self.end_column)
# ==========
# Exceptions
# ==========
class EmptyError(Exception):
pass
class AssertionFailed(Exception):
def __init__(self, source_position: SourcePosition) -> None:
self.source_position = source_position
class ConflictError(Exception):
def __init__(self, source_position: SourcePosition) -> None:
self.source_position = source_position
class NoValueProvided(Exception):
def __init__(self, source_position: SourcePosition) -> None:
self.source_position = source_position
class AssertionFailure(Exception):
def __init__(self, source_position: SourcePosition) -> None:
self.source_position = source_position
# ============================
# Constructors and conversions
# ============================
# -----
# Money
# -----
def money_of_cents_string(v: str) -> Money:
return Money(Integer(v))
def money_of_units_int(v: int) -> Money:
return Money(Integer(v) * Integer(100))
def money_of_cents_integer(v: Integer) -> Money:
return Money(v)
def money_to_float(m: Money) -> float:
return float(mpfr(mpq(m.value.value, 100)))
def money_to_string(m: Money) -> str:
return str(money_to_float(m))
def money_to_cents(m: Money) -> Integer:
return m.value
def money_round(m: Money) -> Money:
res, remainder = t_divmod(m.value.value, 100)
if remainder < 50:
return Money(Integer(res * 100))
else:
return Money(Integer((res + sign(res)) * 100))
def money_of_decimal(d: Decimal) -> Money:
"""
Warning: rounds to the nearest cent
"""
return Money(Integer(mpz(d.value)))
# --------
# Decimals
# --------
def decimal_of_string(d: str) -> Decimal:
return Decimal(d)
def decimal_to_float(d: Decimal) -> float:
return float(mpfr(d.value))
def decimal_of_float(d: float) -> Decimal:
return Decimal(d)
def decimal_of_integer(d: Integer) -> Decimal:
return Decimal(d.value)
def decimal_to_string(precision: int, i: Decimal) -> str:
return "{1:.{0}}".format(precision, mpfr(i.value, precision * 10 // 2))
def decimal_round(q: Decimal) -> Decimal:
"""
Implements the workaround by
https://gmplib.org/list-archives/gmp-discuss/2009-May/003767.html
"""
return Decimal(
mpq(f_div(2*q.value.numerator + q.value.denominator,
2*q.value.denominator), 1) # type:ignore
)
def decimal_of_money(m: Money) -> Decimal:
return Decimal(mpq(qdiv(m.value.value, 100)))
# --------
# Integers
# --------
def integer_of_string(s: str) -> Integer:
return Integer(s)
def integer_to_string(d: Integer) -> str:
return str(d.value)
def integer_of_int(d: int) -> Integer:
return Integer(d)
def integer_to_int(d: Integer) -> int:
return int(d.value)
def integer_exponentiation(i: Integer, e: int) -> Integer:
return i ** e # type: ignore
def integer_log2(i: Integer) -> int:
return int(log2(i.value))
# -----
# Dates
# -----
def day_of_month_of_date(d: Date) -> Integer:
return integer_of_int(d.value.day)
def month_number_of_date(d: Date) -> Integer:
return integer_of_int(d.value.month)
def year_of_date(d: Date) -> Integer:
return integer_of_int(d.value.year)
def date_to_string(d: Date) -> str:
return "{}".format(d.value)
def date_of_numbers(year: int, month: int, day: int) -> Date:
# The datetime.date does not take year=0 as an entry, we trick it into
# 1 in that case because year=0 cases don't care about the actual year
return Date(datetime.date(year if year != 0 else 1, month, day))
def date_of_datetime(d: datetime.date) -> Date:
return Date(d)
def first_day_of_month(d: Date) -> Date:
return Date(datetime.date(d.value.year, d.value.month, 1))
def last_day_of_month(d: Date) -> Date:
return Date(datetime.date(d.value.year, d.value.month, calendar.monthrange(d.value.year, d.value.month)[1]))
# ---------
# Durations
# ---------
def duration_of_numbers(years: int, months: int, days: int) -> Duration:
return Duration(dateutil.relativedelta.relativedelta(years=years, months=months, days=days))
def duration_to_years_months_days(d: Duration) -> Tuple[int, int, int]:
return (d.value.years, d.value.months, d.value.days) # type: ignore
def duration_to_string(s: Duration) -> str:
return "{}".format(s.value)
# -----
# Lists
# -----
def list_fold_left(f: Callable[[Alpha, Beta], Alpha], init: Alpha, l: List[Beta]) -> Alpha:
return reduce(f, l, init)
def list_filter(f: Callable[[Alpha], bool], l: List[Alpha]) -> List[Alpha]:
return [i for i in l if f(i)]
def list_map(f: Callable[[Alpha], Beta], l: List[Alpha]) -> List[Beta]:
return [f(i) for i in l]
def list_reduce(f: Callable[[Alpha, Alpha], Alpha], dft: Alpha, l: List[Alpha]) -> Alpha:
if l == []:
return dft
else:
return reduce(f, l)
def list_length(l: List[Alpha]) -> Integer:
return Integer(len(l))
# ========
# Defaults
# ========
def handle_default(
pos: SourcePosition,
exceptions: List[Callable[[Unit], Alpha]],
just: Callable[[Unit], Alpha],
cons: Callable[[Unit], Alpha]
) -> Alpha:
acc: Optional[Alpha] = None
for exception in exceptions:
new_val: Optional[Alpha]
try:
new_val = exception(Unit())
except EmptyError:
new_val = None
if acc is None:
acc = new_val
elif not (acc is None) and new_val is None:
pass # acc stays the same
elif not (acc is None) and not (new_val is None):
raise ConflictError(pos)
if acc is None:
if just(Unit()):
return cons(Unit())
else:
raise EmptyError
else:
return acc
def handle_default_opt(
pos: SourcePosition,
exceptions: List[Optional[Any]],
just: Optional[bool],
cons: Optional[Alpha]
) -> Optional[Alpha]:
acc: Optional[Alpha] = None
for exception in exceptions:
if acc is None:
acc = exception
elif not (acc is None) and exception is None:
pass # acc stays the same
elif not (acc is None) and not (exception is None):
raise ConflictError(pos)
if acc is None:
if just is None:
return None
else:
if just:
return cons
else:
return None
else:
return acc
def no_input() -> Callable[[Unit], Alpha]:
def closure(_: Unit):
raise EmptyError
return closure
# This value is used for the Python code generation to trump mypy and forcing
# it to accept dead code. Indeed, when raising an exception during a variable
# definition, mypy complains that the later dead code will not know what
# this variable was. So we give this variable a dead value.
dead_value: Any = 0
# =======
# Logging
# =======
class LogEventCode(Enum):
VariableDefinition = 0
BeginCall = 1
EndCall = 2
DecisionTaken = 3
class LogEvent:
def __init__(self, code: LogEventCode, payload: Union[List[str], SourcePosition, Tuple[List[str], Alpha]]) -> None:
self.code = code
self.payload = payload
log: List[LogEvent] = []
def reset_log():
log = []
def retrieve_log() -> List[LogEvent]:
return log
def log_variable_definition(headings: List[str], value: Alpha) -> Alpha:
log.append(LogEvent(LogEventCode.VariableDefinition,
(headings, copy.deepcopy(value))))
return value
def log_begin_call(headings: List[str], f: Callable[[Alpha], Beta], value: Alpha) -> Beta:
log.append(LogEvent(LogEventCode.BeginCall, headings))
return f(value)
def log_end_call(headings: List[str], value: Alpha) -> Alpha:
log.append(LogEvent(LogEventCode.EndCall, headings))
return value
def log_decision_taken(pos: SourcePosition, value: bool) -> bool:
log.append(LogEvent(LogEventCode.DecisionTaken, pos))
return value

View File

@ -610,8 +610,8 @@ let handle_default :
let handle_default_opt
(pos : source_position)
(exceptions : 'a eoption array)
(just : bool eoption)
(cons : 'a eoption) : 'a eoption =
(just : unit -> bool eoption)
(cons : unit -> 'a eoption) : 'a eoption =
let except =
Array.fold_left
(fun acc except ->
@ -624,8 +624,8 @@ let handle_default_opt
match except with
| ESome _ -> except
| ENone _ -> (
match just with
| ESome b -> if b then cons else ENone ()
match just () with
| ESome b -> if b then cons () else ENone ()
| ENone _ -> ENone ())
let no_input : unit -> 'a = fun _ -> raise EmptyError

View File

@ -26,7 +26,6 @@ type date [@@deriving yojson]
type duration [@@deriving yojson]
type date_rounding = Dates_calc.Dates.date_rounding
type source_position = {
filename : string;
start_line : int;
@ -278,8 +277,8 @@ val handle_default :
val handle_default_opt :
source_position ->
'a eoption array ->
bool eoption ->
'a eoption ->
(unit -> bool eoption) ->
(unit -> 'a eoption) ->
'a eoption
(** @raise ConflictError *)

View File

@ -618,8 +618,8 @@ def handle_default(
def handle_default_opt(
pos: SourcePosition,
exceptions: List[Optional[Any]],
just: Optional[bool],
cons: Optional[Alpha]
just: Callable[[Unit],Optional[bool]],
cons: Callable[[Unit],Optional[Alpha]]
) -> Optional[Alpha]:
acc: Optional[Alpha] = None
for exception in exceptions:
@ -630,11 +630,12 @@ def handle_default_opt(
elif not (acc is None) and not (exception is None):
raise ConflictError(pos)
if acc is None:
if just is None:
b = just(Unit())
if b is None:
return None
else:
if just:
return cons
if b:
return cons(Unit())
else:
return None
else: