diff --git a/Makefile b/Makefile index f67c3806..f1eb5599 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,9 @@ help : Makefile ROOT_DIR:=$(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) +# Export all variables to sub-make +export + ########################################## # Dependencies ########################################## @@ -189,12 +192,12 @@ $(FRENCH_LAW_PYTHON_LIB_DIR)/allocations_familiales.py: .FORCE $(FRENCH_LAW_PYTHON_LIB_DIR)/ $(FRENCH_LAW_OCAML_LIB_DIR)/law_source/allocations_familiales.ml: .FORCE - CATALA_OPTS="-O -t" $(MAKE) -C $(ALLOCATIONS_FAMILIALES_DIR) allocations_familiales.ml + CATALA_OPTS="$(CATALA_OPTS) -O -t" $(MAKE) -C $(ALLOCATIONS_FAMILIALES_DIR) allocations_familiales.ml cp -f $(ALLOCATIONS_FAMILIALES_DIR)/allocations_familiales.ml \ $(FRENCH_LAW_OCAML_LIB_DIR)/law_source $(FRENCH_LAW_OCAML_LIB_DIR)/law_source/unit_tests/tests_allocations_familiales.ml: .FORCE - CATALA_OPTS="-O -t" $(MAKE) -s -C $(ALLOCATIONS_FAMILIALES_DIR) tests/tests_allocations_familiales.ml + CATALA_OPTS="$(CATALA_OPTS) -O -t" $(MAKE) -s -C $(ALLOCATIONS_FAMILIALES_DIR) tests/tests_allocations_familiales.ml cp -f $(ALLOCATIONS_FAMILIALES_DIR)/tests/tests_allocations_familiales.ml \ $(FRENCH_LAW_OCAML_LIB_DIR)/law_source/unit_tests/ diff --git a/compiler/driver.ml b/compiler/driver.ml index 893cc1b4..02c3d573 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -89,6 +89,7 @@ let driver (source_file : Pos.input_file) (debug : bool) (unstyled : bool) | Some f -> f | None -> Filename.remove_extension source_file ^ ".d" in + Cli.debug_print (Format.asprintf "Writing list of dependencies to %s..." output_file); let oc = open_out output_file in Printf.fprintf oc "%s:\\\n%s\n%s:" (String.concat "\\\n" diff --git a/compiler/scalc/compile_from_lambda.ml b/compiler/scalc/compile_from_lambda.ml index f3be12ee..39b465f0 100644 --- a/compiler/scalc/compile_from_lambda.ml +++ b/compiler/scalc/compile_from_lambda.ml @@ -138,6 +138,11 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc | L.EAbs ((binder, binder_pos), taus) -> let vars, body = Bindlib.unmbind binder in let vars_tau = List.map2 (fun x tau -> (x, tau)) (Array.to_list vars) taus in + let closure_name = + match ctxt.inside_definition_of with + | None -> A.LocalName.fresh ("closure", Pos.get_position block_expr) + | Some x -> x + in let ctxt = { ctxt with @@ -146,13 +151,9 @@ and translate_statements (ctxt : ctxt) (block_expr : L.expr Pos.marked) : A.bloc (fun var_dict (x, _) -> L.VarMap.add x (A.LocalName.fresh (Bindlib.name_of x, binder_pos)) var_dict) ctxt.var_dict vars_tau; + inside_definition_of = None; } in - let closure_name = - match ctxt.inside_definition_of with - | None -> A.LocalName.fresh ("closure", Pos.get_position block_expr) - | Some x -> x - in let new_body = translate_statements ctxt body in [ ( A.SInnerFuncDef diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index 49af6f4d..7046f221 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -21,7 +21,8 @@ module L = Lcalc.Ast let format_lit (fmt : Format.formatter) (l : L.lit Pos.marked) : unit = match Pos.unmark l with - | LBool b -> Dcalc.Print.format_lit fmt (Pos.same_pos_as (Dcalc.Ast.LBool b) l) + | LBool true -> Format.fprintf fmt "True" + | LBool false -> Format.fprintf fmt "False" | LInt i -> Format.fprintf fmt "integer_of_string(\"%s\")" (Runtime.integer_to_string i) | LUnit -> Format.fprintf fmt "Unit()" | LRat i -> @@ -61,23 +62,23 @@ let format_binop (fmt : Format.formatter) (op : Dcalc.Ast.binop Pos.marked) : un | Lte _ -> Format.fprintf fmt "<=" | Gt _ -> Format.fprintf fmt ">" | Gte _ -> Format.fprintf fmt ">=" - | Map -> Format.fprintf fmt "Array.map" - | Filter -> Format.fprintf fmt "array_filter" + | Map -> Format.fprintf fmt "list_map" + | Filter -> Format.fprintf fmt "list_filter" let format_ternop (fmt : Format.formatter) (op : Dcalc.Ast.ternop Pos.marked) : unit = - match Pos.unmark op with Fold -> Format.fprintf fmt "Array.fold_left" + match Pos.unmark op with Fold -> Format.fprintf fmt "list_fold_left" let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list) : unit = - Format.fprintf fmt "@[[%a]@]" + Format.fprintf fmt "@[[%a]@]" (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt info -> Format.fprintf fmt "\"%a\"" Utils.Uid.MarkedString.format_info info)) uids let format_string_list (fmt : Format.formatter) (uids : string list) : unit = - Format.fprintf fmt "@[[%a]@]" + Format.fprintf fmt "@[[%a]@]" (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ";@ ") + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt info -> Format.fprintf fmt "\"%s\"" info)) uids @@ -85,10 +86,8 @@ let format_unop (fmt : Format.formatter) (op : Dcalc.Ast.unop Pos.marked) : unit match Pos.unmark op with | Minus _ -> Format.fprintf fmt "-" | Not -> Format.fprintf fmt "not" - | Log (entry, infos) -> - Format.fprintf fmt "@[log_entry@ \"%a|%a\"@]" format_log_entry entry format_uid_list - infos - | Length -> Format.fprintf fmt "%s" "array_length" + | Log (entry, infos) -> assert false (* should not happen *) + | Length -> Format.fprintf fmt "%s" "len" | IntToRat -> Format.fprintf fmt "%s" "decimal_of_integer" | GetDay -> Format.fprintf fmt "%s" "day_of_month_of_date" | GetMonth -> Format.fprintf fmt "%s" "month_number_of_date" @@ -124,20 +123,6 @@ let format_enum_cons_name (fmt : Format.formatter) (v : Dcalc.Ast.EnumConstructo Format.fprintf fmt "%s" (avoid_keywords (to_ascii (Format.asprintf "%a" Dcalc.Ast.EnumConstructor.format_t v))) -let rec typ_embedding_name (fmt : Format.formatter) (ty : D.typ Pos.marked) : unit = - match Pos.unmark ty with - | D.TLit D.TUnit -> Format.fprintf fmt "embed_unit" - | D.TLit D.TBool -> Format.fprintf fmt "embed_bool" - | D.TLit D.TInt -> Format.fprintf fmt "embed_integer" - | D.TLit D.TRat -> Format.fprintf fmt "embed_decimal" - | D.TLit D.TMoney -> Format.fprintf fmt "embed_money" - | D.TLit D.TDate -> Format.fprintf fmt "embed_date" - | D.TLit D.TDuration -> Format.fprintf fmt "embed_duration" - | D.TTuple (_, Some s_name) -> Format.fprintf fmt "embed_%a" format_struct_name s_name - | D.TEnum (_, e_name) -> Format.fprintf fmt "embed_%a" format_enum_name e_name - | D.TArray ty -> Format.fprintf fmt "embed_array (%a)" typ_embedding_name ty - | _ -> Format.fprintf fmt "unembeddable" - let typ_needs_parens (e : Dcalc.Ast.typ Pos.marked) : bool = match Pos.unmark e with TArrow _ | TArray _ -> true | _ -> false @@ -166,31 +151,24 @@ let rec format_typ (fmt : Format.formatter) (typ : Dcalc.Ast.typ Pos.marked) : u | TArrow (t1, t2) -> Format.fprintf fmt "Callable[[%a], %a]" format_typ_with_parens t1 format_typ_with_parens t2 | TArray t1 -> Format.fprintf fmt "List[%a]" format_typ_with_parens t1 - | TAny -> Format.fprintf fmt "_" + | TAny -> Format.fprintf fmt "Any" + +let format_name_cleaned (fmt : Format.formatter) (s : string) : unit = + let lowercase_name = to_lowercase (to_ascii s) in + let lowercase_name = + Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") lowercase_name + in + let lowercase_name = avoid_keywords (to_ascii lowercase_name) in + Format.fprintf fmt "%s" lowercase_name let format_var (fmt : Format.formatter) (v : LocalName.t) : unit = let v_str = Pos.unmark (LocalName.get_info v) in - let lowercase_name = to_lowercase (to_ascii v_str) in - let lowercase_name = - Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") lowercase_name - in - let lowercase_name = avoid_keywords (to_ascii lowercase_name) in - if lowercase_name = "handle_default" || Dcalc.Print.begins_with_uppercase v_str then - Format.fprintf fmt "%s" lowercase_name - else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name - else Format.fprintf fmt "%s_" lowercase_name + if v_str = "_" then Format.fprintf fmt "_" + else Format.fprintf fmt "%a_%d" format_name_cleaned v_str (LocalName.hash v) -let format_func_name (fmt : Format.formatter) (v : TopLevelName.t) : unit = +let format_toplevel_name (fmt : Format.formatter) (v : TopLevelName.t) : unit = let v_str = Pos.unmark (TopLevelName.get_info v) in - let lowercase_name = to_lowercase (to_ascii v_str) in - let lowercase_name = - Re.Pcre.substitute ~rex:(Re.Pcre.regexp "\\.") ~subst:(fun _ -> "_dot_") lowercase_name - in - let lowercase_name = avoid_keywords (to_ascii lowercase_name) in - if lowercase_name = "handle_default" || Dcalc.Print.begins_with_uppercase v_str then - Format.fprintf fmt "%s" lowercase_name - else if lowercase_name = "_" then Format.fprintf fmt "%s" lowercase_name - else Format.fprintf fmt "%s_" lowercase_name + format_name_cleaned fmt v_str let needs_parens (e : expr Pos.marked) : bool = match Pos.unmark e with ELit (LBool _ | LUnit) | EVar _ | EOp _ -> false | _ -> true @@ -203,75 +181,121 @@ let format_exception (fmt : Format.formatter) (exc : L.except Pos.marked) : unit | NoValueProvided -> let pos = Pos.get_position exc in Format.fprintf fmt - "NoValueProvided(SourcePosition(filename = \"%s\",@ start_line=%d,@ start_column=%d,@ \ + "NoValueProvided(SourcePosition(filename=\"%s\",@ start_line=%d,@ start_column=%d,@ \ end_line=%d,@ end_column=%d,@ law_headings=%a))" (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) -let rec format_expr (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) : unit - = - let format_expr = format_expr ctx in - let format_with_parens (fmt : Format.formatter) (e : expr Pos.marked) = - if needs_parens e then Format.fprintf fmt "(%a)" format_expr e - else Format.fprintf fmt "%a" format_expr e - in +let rec format_expression (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (e : expr Pos.marked) + : unit = match Pos.unmark e with - | EFunc v -> Format.fprintf fmt "%a" format_func_name v - | EVar v -> Format.fprintf fmt "%a" format_var v + | EVar v -> format_var fmt v + | EFunc f -> format_toplevel_name fmt f | EStruct (es, s) -> if List.length es = 0 then failwith "should not happen" else - Format.fprintf fmt "%a(@[%a@])" format_struct_name s + Format.fprintf fmt "@[%a(%a)@]" format_struct_name s (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (fun fmt (e, struct_field) -> - Format.fprintf fmt "%a = %a" format_struct_field_name struct_field format_with_parens - e)) + Format.fprintf fmt "%a = %a" format_struct_field_name struct_field + (format_expression ctx) e)) (List.combine es (List.map fst (Dcalc.Ast.StructMap.find s ctx.ctx_structs))) + | EStructFieldAccess (e1, field, _) -> + Format.fprintf fmt "%a.%a" (format_expression ctx) e1 format_struct_field_name field + | EInj (e, cons, enum_name) -> + Format.fprintf fmt "@[%a_%a(%a)@]" format_enum_name enum_name format_enum_cons_name + cons (format_expression ctx) e | EArray es -> - Format.fprintf fmt "@[[%a]@]" + Format.fprintf fmt "@[[%a]@]" (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (fun fmt e -> Format.fprintf fmt "%a" format_with_parens e)) + (fun fmt e -> Format.fprintf fmt "%a" (format_expression ctx) e)) es - | EStructFieldAccess (e1, field, _) -> - Format.fprintf fmt "%a.%a" format_with_parens e1 format_struct_field_name field - | EInj (e, cons, _) -> - Format.fprintf fmt "@[%a(%a)@]" format_enum_cons_name cons format_expr e | ELit l -> Format.fprintf fmt "%a" format_lit (Pos.same_pos_as l e) | EApp ((EOp (Binop ((Dcalc.Ast.Map | Dcalc.Ast.Filter) as op)), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_binop (op, Pos.no_pos) format_with_parens - arg1 format_with_parens arg2 + Format.fprintf fmt "@[%a(%a,@ %a)@]" format_binop (op, Pos.no_pos) + (format_expression ctx) arg1 (format_expression ctx) arg2 | EApp ((EOp (Binop op), _), [ arg1; arg2 ]) -> - Format.fprintf fmt "@[%a@ %a@ %a@]" format_with_parens arg1 format_binop - (op, Pos.no_pos) format_with_parens arg2 + Format.fprintf fmt "@[(%a %a %a)@]" (format_expression ctx) arg1 format_binop + (op, Pos.no_pos) (format_expression ctx) arg2 | EApp ((EApp ((EOp (Unop (D.Log (D.BeginCall, info))), _), [ f ]), _), [ arg ]) when !Cli.trace_flag -> - Format.fprintf fmt "(log_begin_call@ %a@ %a@ %a)" format_uid_list info format_with_parens f - format_with_parens arg + Format.fprintf fmt "@[log_begin_call(%a,@ %a,@ %a)@]" format_uid_list info + (format_expression ctx) f (format_expression ctx) arg | EApp ((EOp (Unop (D.Log (D.VarDef tau, info))), _), [ arg1 ]) when !Cli.trace_flag -> - Format.fprintf fmt "(log_variable_definition@ %a@ (%a)@ %a)" format_uid_list info - typ_embedding_name (tau, Pos.no_pos) format_with_parens arg1 + Format.fprintf fmt "@[log_variable_definition(%a,@ %a)@]" format_uid_list info + (format_expression ctx) arg1 | EApp ((EOp (Unop (D.Log (D.PosRecordIfTrueBool, _))), pos), [ arg1 ]) when !Cli.trace_flag -> Format.fprintf fmt - "(log_decision_taken@ @[{filename = \"%s\";@ start_line=%d;@ start_column=%d;@ \ - end_line=%d; end_column=%d;@ law_headings=%a}@]@ %a)" + "@[log_decision_taken(SourcePosition(filename=\"%s\",@ start_line=%d,@ \ + start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]" (Pos.get_file pos) (Pos.get_start_line pos) (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) format_string_list (Pos.get_law_info pos) - format_with_parens arg1 + (format_expression ctx) arg1 | EApp ((EOp (Unop (D.Log (D.EndCall, info))), _), [ arg1 ]) when !Cli.trace_flag -> - Format.fprintf fmt "(log_end_call@ %a@ %a)" format_uid_list info format_with_parens arg1 - | EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) -> Format.fprintf fmt "%a" format_with_parens arg1 + Format.fprintf fmt "@[log_end_call(%a,@ %a)@]" format_uid_list info + (format_expression ctx) arg1 + | EApp ((EOp (Unop (D.Log _)), _), [ arg1 ]) -> + Format.fprintf fmt "%a" (format_expression ctx) arg1 + | EApp ((EOp (Unop ((Minus _ | Not) as op)), _), [ arg1 ]) -> + Format.fprintf fmt "@[%a %a@]" format_unop (op, Pos.no_pos) (format_expression ctx) + arg1 | EApp ((EOp (Unop op), _), [ arg1 ]) -> - Format.fprintf fmt "@[%a@ %a@]" format_unop (op, Pos.no_pos) format_with_parens arg1 + Format.fprintf fmt "@[%a(%a)@]" format_unop (op, Pos.no_pos) (format_expression ctx) + arg1 | EApp (f, args) -> - Format.fprintf fmt "@[%a(%a)@]" format_with_parens f - (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") format_expr) + Format.fprintf fmt "@[%a(%a)@]" (format_expression ctx) f + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + (format_expression ctx)) args | EOp (Ternop op) -> Format.fprintf fmt "%a" format_ternop (op, Pos.no_pos) | EOp (Binop op) -> Format.fprintf fmt "%a" format_binop (op, Pos.no_pos) | EOp (Unop op) -> Format.fprintf fmt "%a" format_unop (op, Pos.no_pos) +let rec format_statement (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (s : stmt Pos.marked) : + unit = + match Pos.unmark s with + | SInnerFuncDef (name, { func_params; func_body }) -> + Format.fprintf fmt "@[def %a(%a):@\n%a@]" format_var (Pos.unmark name) + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt (var, typ) -> + Format.fprintf fmt "%a:%a" format_var (Pos.unmark var) format_typ typ)) + func_params (format_block ctx) func_body + | SLocalDecl _ -> () (* We don't need to declare variables in Python *) + | SLocalDef (v, e) -> + Format.fprintf fmt "%a = %a" format_var (Pos.unmark v) (format_expression ctx) e + | STryExcept (try_b, except, catch_b) -> + Format.fprintf fmt "@[try:@\n%a@]@\n@[except %a:@\n%a@]" (format_block ctx) + try_b format_exception (except, Pos.no_pos) (format_block ctx) catch_b + | SRaise except -> + Format.fprintf fmt "@[raise %a@]" format_exception (except, Pos.get_position s) + | SIfThenElse (cond, b1, b2) -> + Format.fprintf fmt "@[if %a:@\n%a@]@\n@[else:@\n%a@]" (format_expression ctx) + cond (format_block ctx) b1 (format_block ctx) b2 + | SSwitch (e1, e_name, cases) -> + let cases = + List.map2 (fun (x, y) (cons, _) -> (x, y, cons)) cases (D.EnumMap.find e_name ctx.ctx_enums) + in + let tmp_var = LocalName.fresh ("match_arg", Pos.no_pos) in + Format.fprintf fmt "%a = %a@\n@[if %a@]" format_var tmp_var (format_expression ctx) e1 + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt "@]@\n@[elif ") + (fun fmt (case_block, payload_var, cons_name) -> + Format.fprintf fmt "%a is %a_%a:@\n%a = %a.value@\n%a" format_var tmp_var + format_enum_name e_name format_enum_cons_name cons_name format_var payload_var + format_var tmp_var (format_block ctx) case_block)) + cases + | SReturn e1 -> + Format.fprintf fmt "@[return %a@]" (format_expression ctx) (e1, Pos.get_position s) + | SAssert e1 -> + Format.fprintf fmt "@[assert %a@]" (format_expression ctx) (e1, Pos.get_position s) + +and format_block (ctx : Dcalc.Ast.decl_ctx) (fmt : Format.formatter) (b : block) : unit = + Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n") (format_statement ctx) fmt b + let format_ctx (type_ordering : Scopelang.Dependency.TVertex.t list) (fmt : Format.formatter) (ctx : D.decl_ctx) : unit = let format_struct_decl fmt (struct_name, struct_fields) = @@ -340,6 +364,7 @@ let format_program (fmt : Format.formatter) (p : Ast.program) "# This file has been generated by the Catala compiler, do not edit!\n\ @\n\ from .catala_runtime import *@\n\ + from typing import Any, List, Callable, Tuple\n\ @\n\ %a@\n\ @\n\ @@ -347,5 +372,11 @@ let format_program (fmt : Format.formatter) (p : Ast.program) (format_ctx type_ordering) p.decl_ctx (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n@\n") - (fun fmt (name, { Ast.func_params; Ast.func_body }) -> assert false)) + (fun fmt (name, { Ast.func_params; Ast.func_body }) -> + Format.fprintf fmt "@[def %a(%a):@\n%a@]" format_toplevel_name name + (Format.pp_print_list + ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") + (fun fmt (var, typ) -> + Format.fprintf fmt "%a:%a" format_var (Pos.unmark var) format_typ typ)) + func_params (format_block p.decl_ctx) func_body)) p.scopes diff --git a/french_law/python/__init__.py b/french_law/python/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/french_law/python/catala_runtime.py b/french_law/python/catala_runtime.py index d54a7b21..61ba6567 100644 --- a/french_law/python/catala_runtime.py +++ b/french_law/python/catala_runtime.py @@ -13,6 +13,10 @@ from gmpy2 import log2, mpz, mpq, mpfr, mpc # type: ignore import datetime import dateutil.relativedelta # type: ignore from typing import NewType, List, Callable, Tuple, Optional, TypeVar, Iterable +from functools import reduce + +Alpha = TypeVar('Alpha') +Beta = TypeVar('Beta') # ===== # Types @@ -241,24 +245,38 @@ def duration_to_years_months_days(d: Duration) -> Tuple[int, int, int]: def duration_to_string(s: Duration) -> str: return "{}".format(s) +# ----- +# Lists +# ----- + + +def list_fold_left(f: Callable[[Alpha, Beta], Alpha], init: Alpha, l: List[Beta]) -> Alpha: + return reduce(f, l, init) + + +def list_filter(f: Callable[[Alpha], bool], l: List[Alpha]) -> List[Alpha]: + return [i for i in l if f(i)] + + +def list_map(f: Callable[[Alpha], Beta], l: List[Alpha]) -> List[Beta]: + return [f(i) for i in l] + + # ======== # Defaults # ======== -Alpha = TypeVar('Alpha') - - def handle_default( - exceptions: List[Callable[[], Alpha]], - just: Callable[[], Alpha], - cons: Callable[[], Alpha] + exceptions: List[Callable[[Unit], Alpha]], + just: Callable[[Unit], Alpha], + cons: Callable[[Unit], Alpha] ) -> Alpha: acc: Optional[Alpha] = None for exception in exceptions: new_val: Optional[Alpha] try: - new_val = exception() + new_val = exception(Unit()) except EmptyError: new_val = None if acc is None: @@ -268,8 +286,8 @@ def handle_default( elif not (acc is None) and not (new_val is None): raise ConflictError if acc is None: - if just(): - return cons() + if just(Unit()): + return cons(Unit()) else: raise EmptyError else: @@ -279,3 +297,23 @@ def handle_default( def no_input() -> Callable[[], Alpha]: # From https://stackoverflow.com/questions/8294618/define-a-lambda-expression-that-raises-an-exception return (_ for _ in ()).throw(EmptyError) + +# ======= +# Logging +# ======= + + +def log_variable_definition(headings: List[str], value: Alpha) -> Alpha: + return value + + +def log_begin_call(headings: List[str], f: Callable[[Alpha], Beta], value: Alpha) -> Beta: + return f(value) + + +def log_end_call(headings: List[str], value: Alpha) -> Alpha: + return value + + +def log_decision_taken(pos: SourcePosition, value: Alpha) -> Alpha: + return value