Python runtime OK now

This commit is contained in:
Denis Merigoux 2023-05-26 17:08:26 +02:00
parent 8987d358e7
commit fa9f432e8b
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
2 changed files with 38 additions and 34 deletions

View File

@ -13,7 +13,6 @@
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License. *)
[@@@warning "-32-27"]
open Catala_utils
open Shared_ast
@ -42,16 +41,9 @@ let format_lit (fmt : Format.formatter) (l : lit Mark.pos) : unit =
let years, months, days = Runtime.duration_to_years_months_days d in
Format.fprintf fmt "duration_of_numbers(%d,%d,%d)" years months days
let format_log_entry (fmt : Format.formatter) (entry : log_entry) : unit =
match entry with
| VarDef _ -> Format.pp_print_string fmt ":="
| BeginCall -> Format.pp_print_string fmt ""
| EndCall -> Format.fprintf fmt "%s" ""
| PosRecordIfTrueBool -> Format.pp_print_string fmt ""
let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit =
match Mark.remove op with
| Log (entry, infos) -> assert false
| Log (_entry, _infos) -> assert false
| Minus_int | Minus_rat | Minus_mon | Minus_dur ->
Format.pp_print_string fmt "-"
(* Todo: use the names from [Operator.name] *)
@ -247,14 +239,6 @@ let format_func_name (fmt : Format.formatter) (v : FuncName.t) : unit =
let v_str = Mark.remove (FuncName.get_info v) in
format_name_cleaned fmt v_str
let format_var_name (fmt : Format.formatter) (v : VarName.t) : unit =
Format.fprintf fmt "%a_%s" VarName.format_t v (string_of_int (VarName.hash v))
let needs_parens (e : expr) : bool =
match Mark.remove e with
| ELit (LBool _ | LUnit) | EVar _ | EOp _ -> false
| _ -> true
let format_exception (fmt : Format.formatter) (exc : except Mark.pos) : unit =
let pos = Mark.get exc in
match Mark.remove exc with
@ -325,9 +309,16 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) :
when !Cli.trace_flag ->
Format.fprintf fmt "log_begin_call(%a,@ %a,@ %a)" format_uid_list info
(format_expression ctx) f (format_expression ctx) arg
| EApp ((EOp (Log (VarDef tau, info)), _), [arg1]) when !Cli.trace_flag ->
Format.fprintf fmt "log_variable_definition(%a,@ %a)" format_uid_list info
(format_expression ctx) arg1
| EApp ((EOp (Log (VarDef var_def_info, info)), _), [arg1])
when !Cli.trace_flag ->
Format.fprintf fmt
"log_variable_definition(%a,@ LogIO(io_input=%s,@ io_output=%b),@ %a)"
format_uid_list info
(match var_def_info.log_io_input with
| Runtime.NoInput -> "NoInput"
| Runtime.OnlyInput -> "OnlyInput"
| Runtime.Reentrant -> "Reentrant")
var_def_info.log_io_output (format_expression ctx) arg1
| EApp ((EOp (Log (PosRecordIfTrueBool, _)), pos), [arg1])
when !Cli.trace_flag ->
Format.fprintf fmt
@ -556,7 +547,7 @@ let format_ctx
format_enum_name enum_name
(Format.pp_print_list
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt (i, enum_cons, enum_cons_type) ->
(fun fmt (i, enum_cons, _enum_cons_type) ->
Format.fprintf fmt "%a = %d" format_enum_cons_name enum_cons i))
(List.mapi
(fun i (x, y) -> i, x, y)

View File

@ -40,7 +40,7 @@ class Integer:
return Integer(self.value * other.value)
def __truediv__(self, other: 'Integer') -> 'Decimal':
return Decimal (self.value) / Decimal (other.value)
return Decimal(self.value) / Decimal(other.value)
def __neg__(self: 'Integer') -> 'Integer':
return Integer(- self.value)
@ -155,7 +155,7 @@ class Money:
elif isinstance(other, Decimal):
return self * (1. / other.value)
else:
raise Exception("Dividing money and invalid obj")
raise Exception("Dividing money and invalid obj")
def __neg__(self: 'Money') -> 'Money':
return Money(- self.value)
@ -200,11 +200,11 @@ class Date:
def __sub__(self, other: object) -> object:
if isinstance(other, Date):
return Duration(dateutil.relativedelta.relativedelta(days=(self.value - other.value).days))
return Duration(dateutil.relativedelta.relativedelta(days=(self.value - other.value).days))
elif isinstance(other, Duration):
return Date(self.value - other.value)
return Date(self.value - other.value)
else:
raise Exception("Substracting date and invalid obj")
raise Exception("Substracting date and invalid obj")
def __lt__(self, other: 'Date') -> bool:
return self.value < other.value
@ -618,8 +618,8 @@ def handle_default(
def handle_default_opt(
pos: SourcePosition,
exceptions: List[Optional[Any]],
just: Callable[[Unit],Optional[bool]],
cons: Callable[[Unit],Optional[Alpha]]
just: Callable[[Unit], Optional[bool]],
cons: Callable[[Unit], Optional[Alpha]]
) -> Optional[Alpha]:
acc: Optional[Alpha] = None
for exception in exceptions:
@ -666,9 +666,22 @@ class LogEventCode(Enum):
DecisionTaken = 3
class InputIO(Enum):
NoInput = 0
OnlyInput = 1
Reentrant = 2
class LogIO:
def __init__(self, input_io: InputIO, output_io: bool):
self.input_io = input_io
self.output_io = output_io
class LogEvent:
def __init__(self, code: LogEventCode, payload: Union[List[str], SourcePosition, Tuple[List[str], Alpha]]) -> None:
def __init__(self, code: LogEventCode, io: Optional[LogIO], payload: Union[List[str], SourcePosition, Tuple[List[str], Alpha]]) -> None:
self.code = code
self.io = io
self.payload = payload
@ -683,22 +696,22 @@ def retrieve_log() -> List[LogEvent]:
return log
def log_variable_definition(headings: List[str], value: Alpha) -> Alpha:
log.append(LogEvent(LogEventCode.VariableDefinition,
def log_variable_definition(headings: List[str], io: LogIO, value: Alpha) -> Alpha:
log.append(LogEvent(LogEventCode.VariableDefinition, io,
(headings, copy.deepcopy(value))))
return value
def log_begin_call(headings: List[str], f: Callable[[Alpha], Beta], value: Alpha) -> Beta:
log.append(LogEvent(LogEventCode.BeginCall, headings))
log.append(LogEvent(LogEventCode.BeginCall, None, headings))
return f(value)
def log_end_call(headings: List[str], value: Alpha) -> Alpha:
log.append(LogEvent(LogEventCode.EndCall, headings))
log.append(LogEvent(LogEventCode.EndCall, None, headings))
return value
def log_decision_taken(pos: SourcePosition, value: bool) -> bool:
log.append(LogEvent(LogEventCode.DecisionTaken, pos))
log.append(LogEvent(LogEventCode.DecisionTaken, None, pos))
return value