diff --git a/compiler/lcalc/closure_conversion.ml b/compiler/lcalc/closure_conversion.ml index 1a3bb032..c20a5881 100644 --- a/compiler/lcalc/closure_conversion.ml +++ b/compiler/lcalc/closure_conversion.ml @@ -257,7 +257,7 @@ let rec transform_closures_expr : free_vars, build_closure ctx (Var.Map.bindings free_vars) body vars tys m | EAppOp { - op = ((HandleDefaultOpt | Fold | Map | Map2 | Filter | Reduce), _) as op; + op = ((HandleExceptions | Fold | Map | Map2 | Filter | Reduce), _) as op; tys; args; } -> @@ -534,12 +534,7 @@ let rec hoist_closures_expr : in ( collected_closures, Expr.eapp ~f:(Expr.eabs new_binder tys e1_pos) ~args:new_args ~tys m ) - | EAppOp - { - op = ((HandleDefaultOpt | Fold | Map | Filter | Reduce), _) as op; - tys; - args; - } -> + | EAppOp { op = ((Fold | Map | Filter | Reduce), _) as op; tys; args } -> (* Special case for some operators: its arguments closures thunks because if you want to extract it as a function you need these closures to preserve evaluation order, but backends that don't support closures will simply diff --git a/compiler/lcalc/from_dcalc.ml b/compiler/lcalc/from_dcalc.ml index 7f8ef858..f147a38e 100644 --- a/compiler/lcalc/from_dcalc.ml +++ b/compiler/lcalc/from_dcalc.ml @@ -60,26 +60,43 @@ let rec translate_default (* Since the program is well typed, all exceptions have as type [option 't] *) let pos = Expr.mark_pos mark_default in let exceptions = List.map translate_expr exceptions in - let exceptions_and_cons_ty = Expr.maybe_ty mark_default in - Expr.eappop - ~op:(Op.HandleDefaultOpt, Expr.pos cons) - ~tys: - [ - TArray exceptions_and_cons_ty, pos; - TArrow ([TLit TUnit, pos], (TLit TBool, pos)), pos; - TArrow ([TLit TUnit, pos], exceptions_and_cons_ty), pos; - ] - ~args: - [ - Expr.earray exceptions - (Expr.map_ty (fun ty -> TArray ty, pos) mark_default); - (* In call-by-value programming languages, as lcalc, arguments are - evalulated before calling the function. Since we don't want to - execute the justification and conclusion while before checking every - exceptions, we need to thunk them. *) - Expr.thunk_term (translate_expr just); - Expr.thunk_term (translate_expr cons); - ] + let ty_option = Expr.maybe_ty mark_default in + let ty_array = TArray ty_option, pos in + let ty_alpha = + match ty_option with + | TOption ty, _ -> ty + | (TAny, _) as ty -> ty + | _ -> assert false + in + let mark_alpha = Expr.with_ty mark_default ty_alpha in + Expr.ematch ~name:Expr.option_enum + ~e: + (Expr.eappop + ~op:(Op.HandleExceptions, Expr.pos cons) + ~tys:[ty_array] + ~args:[Expr.earray exceptions (Expr.with_ty mark_default ty_array)] + mark_default) + ~cases: + (EnumConstructor.Map.of_list + [ + (* Some x -> Some x *) + ( Expr.some_constr, + let x = Var.make "x" in + Expr.make_abs [| x |] + (Expr.einj ~name:Expr.option_enum ~cons:Expr.some_constr + ~e:(Expr.evar x mark_alpha) mark_default) + [ty_alpha] pos ); + (* None -> if just then cons else None *) + ( Expr.none_constr, + Expr.thunk_term + (Expr.eifthenelse (translate_expr just) (translate_expr cons) + (Expr.einj + ~e: + (Expr.elit LUnit + (Expr.with_ty mark_default (TLit TUnit, pos))) + ~cons:Expr.none_constr ~name:Expr.option_enum mark_default) + mark_default) ); + ]) mark_default and translate_expr (e : 'm D.expr) : 'm A.expr boxed = diff --git a/compiler/lcalc/to_ocaml.ml b/compiler/lcalc/to_ocaml.ml index a5690f7a..a4091dd7 100644 --- a/compiler/lcalc/to_ocaml.ml +++ b/compiler/lcalc/to_ocaml.ml @@ -409,21 +409,21 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : format_with_parens arg1 | EAppOp { op = Log _, _; args = [arg1]; _ } -> Format.fprintf fmt "%a" format_with_parens arg1 - | EAppOp - { - op = (HandleDefaultOpt as op), _; - args = (EArray excs, _) :: _ as args; - _; - } -> - let pos = List.map Expr.pos excs in - Format.fprintf fmt "@[%s@ [|%a|]@ %a@]" - (Print.operator_to_string op) - (Format.pp_print_list - ~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ") - format_pos) - pos - (Format.pp_print_list ~pp_sep:Format.pp_print_space format_with_parens) - args + (* | EAppOp + * { + * op = (HandleDefaultOpt as op), _; + * args = (EArray excs, _) :: _ as args; + * _; + * } -> + * let pos = List.map Expr.pos excs in + * Format.fprintf fmt "@[%s@ [|%a|]@ %a@]" + * (Print.operator_to_string op) + * (Format.pp_print_list + * ~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ") + * format_pos) + * pos + * (Format.pp_print_list ~pp_sep:Format.pp_print_space format_with_parens) + * args *) | EApp { f; args; _ } -> Format.fprintf fmt "@[%a@ %a@]" format_with_parens f (Format.pp_print_list @@ -443,6 +443,12 @@ let rec format_expr (ctx : decl_ctx) (fmt : Format.formatter) (e : 'm expr) : Format.fprintf ppf "%a@ " format_pos pos | Div_int_int | Div_rat_rat | Div_mon_mon | Div_mon_rat | Div_dur_dur -> Format.fprintf ppf "%a@ " format_pos (Expr.pos (List.nth args 1)) + | HandleExceptions -> + Format.fprintf ppf "[|@[%a@]|]@ " + (Format.pp_print_list + ~pp_sep:(fun ppf () -> Format.fprintf ppf ";@ ") + format_pos) + (List.map Expr.pos args) | _ -> ()) (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ ") diff --git a/compiler/plugins/explain.ml b/compiler/plugins/explain.ml index dd18be61..afed754b 100644 --- a/compiler/plugins/explain.ml +++ b/compiler/plugins/explain.ml @@ -1085,7 +1085,7 @@ let expr_to_dot_label0 : | Reduce -> xlang () ~en:"reduce" ~fr:"réunion" | Filter -> xlang () ~en:"filter" ~fr:"filtre" | Fold -> xlang () ~en:"fold" ~fr:"pliage" - | HandleDefaultOpt -> "" + | HandleExceptions -> "" | ToClosureEnv -> "" | FromClosureEnv -> "" in diff --git a/compiler/scalc/ast.ml b/compiler/scalc/ast.ml index 48c2b2da..90c686d1 100644 --- a/compiler/scalc/ast.ml +++ b/compiler/scalc/ast.ml @@ -34,8 +34,7 @@ module VarName = () let dead_value = VarName.fresh ("dead_value", Pos.no_pos) -let handle_default = FuncName.fresh ("handle_default", Pos.no_pos) -let handle_default_opt = FuncName.fresh ("handle_default_opt", Pos.no_pos) +let handle_exceptions = FuncName.fresh ("handle_exceptions", Pos.no_pos) type operator = Shared_ast.lcalc Shared_ast.operator diff --git a/compiler/scalc/from_lcalc.ml b/compiler/scalc/from_lcalc.ml index ca4933eb..f8331f34 100644 --- a/compiler/scalc/from_lcalc.ml +++ b/compiler/scalc/from_lcalc.ml @@ -138,15 +138,15 @@ and translate_expr (ctxt : 'm ctxt) (expr : 'm L.expr) : RevBlock.t * A.expr = | ETupleAccess { e = e1; index; _ } -> let e1_stmts, new_e1 = translate_expr ctxt e1 in e1_stmts, (A.ETupleAccess { e1 = new_e1; index }, Expr.pos expr) - | EAppOp - { - op = Op.HandleDefaultOpt, _; - args = [_exceptions; _just; _cons]; - tys = _; - } - when ctxt.config.keep_special_ops -> - (* This should be translated as a statement *) - raise (NotAnExpr { needs_a_local_decl = true }) + (* | EAppOp + * { + * op = Op.HandleDefaultOpt, _; + * args = [_exceptions; _just; _cons]; + * tys = _; + * } + * when ctxt.config.keep_special_ops -> + * (\* This should be translated as a statement *\) + * raise (NotAnExpr { needs_a_local_decl = true }) *) | EAppOp { op; args; tys = _ } -> let args_stmts, new_args = translate_expr_list ctxt args in (* FIXME: what happens if [arg] is not a tuple but reduces to one ? *) @@ -274,60 +274,60 @@ and translate_statements (ctxt : 'm ctxt) (block_expr : 'm L.expr) : A.block = ~tail:[A.SAssert (Mark.remove new_e), Expr.pos block_expr] e_stmts | EFatalError err -> [SFatalError err, Expr.pos block_expr] - | EAppOp - { op = Op.HandleDefaultOpt, _; tys = _; args = [exceptions; just; cons] } - when ctxt.config.keep_special_ops -> - let exceptions = - match Mark.remove exceptions with - | EStruct { fields; _ } -> ( - let _, exceptions = - List.find - (fun (field, _) -> - String.equal (Mark.remove (StructField.get_info field)) "content") - (StructField.Map.bindings fields) - in - match Mark.remove exceptions with - | EArray exceptions -> exceptions - | _ -> failwith "should not happen") - | _ -> failwith "should not happen" - in - let just = unthunk just in - let cons = unthunk cons in - let exceptions_stmts, new_exceptions = - translate_expr_list ctxt exceptions - in - let just_stmts, new_just = translate_expr ctxt just in - let cons_stmts, new_cons = translate_expr ctxt cons in - RevBlock.rebuild exceptions_stmts - ~tail: - (RevBlock.rebuild just_stmts - ~tail: - [ - ( A.SSpecialOp - (OHandleDefaultOpt - { - exceptions = new_exceptions; - just = new_just; - cons = - RevBlock.rebuild cons_stmts - ~tail: - [ - ( (match ctxt.inside_definition_of with - | None -> A.SReturn (Mark.remove new_cons) - | Some x -> - A.SLocalDef - { - name = Mark.copy new_cons x; - expr = new_cons; - typ = - Expr.maybe_ty (Mark.get block_expr); - }), - Expr.pos block_expr ); - ]; - return_typ = Expr.maybe_ty (Mark.get block_expr); - }), - Expr.pos block_expr ); - ]) + (* | EAppOp + * { op = Op.HandleDefaultOpt, _; tys = _; args = [exceptions; just; cons] } + * when ctxt.config.keep_special_ops -> + * let exceptions = + * match Mark.remove exceptions with + * | EStruct { fields; _ } -> ( + * let _, exceptions = + * List.find + * (fun (field, _) -> + * String.equal (Mark.remove (StructField.get_info field)) "content") + * (StructField.Map.bindings fields) + * in + * match Mark.remove exceptions with + * | EArray exceptions -> exceptions + * | _ -> failwith "should not happen") + * | _ -> failwith "should not happen" + * in + * let just = unthunk just in + * let cons = unthunk cons in + * let exceptions_stmts, new_exceptions = + * translate_expr_list ctxt exceptions + * in + * let just_stmts, new_just = translate_expr ctxt just in + * let cons_stmts, new_cons = translate_expr ctxt cons in + * RevBlock.rebuild exceptions_stmts + * ~tail: + * (RevBlock.rebuild just_stmts + * ~tail: + * [ + * ( A.SSpecialOp + * (OHandleDefaultOpt + * { + * exceptions = new_exceptions; + * just = new_just; + * cons = + * RevBlock.rebuild cons_stmts + * ~tail: + * [ + * ( (match ctxt.inside_definition_of with + * | None -> A.SReturn (Mark.remove new_cons) + * | Some x -> + * A.SLocalDef + * { + * name = Mark.copy new_cons x; + * expr = new_cons; + * typ = + * Expr.maybe_ty (Mark.get block_expr); + * }), + * Expr.pos block_expr ); + * ]; + * return_typ = Expr.maybe_ty (Mark.get block_expr); + * }), + * Expr.pos block_expr ); + * ]) *) | EApp { f = EAbs { binder; tys }, binder_mark; args; _ } -> (* This defines multiple local variables at the time *) let binder_pos = Expr.mark_pos binder_mark in diff --git a/compiler/scalc/to_c.ml b/compiler/scalc/to_c.ml index b71047fe..72e18ad5 100644 --- a/compiler/scalc/to_c.ml +++ b/compiler/scalc/to_c.ml @@ -313,7 +313,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit = | Reduce -> Format.pp_print_string fmt "catala_list_reduce" | Filter -> Format.pp_print_string fmt "catala_list_filter" | Fold -> Format.pp_print_string fmt "catala_list_fold_left" - | HandleDefaultOpt | FromClosureEnv | ToClosureEnv | Map2 -> + | HandleExceptions | FromClosureEnv | ToClosureEnv | Map2 -> failwith "unimplemented" let _format_string_list (fmt : Format.formatter) (uids : string list) : unit = @@ -367,8 +367,8 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : Format.fprintf fmt "%a %a" format_op op (format_expression ctx) arg1 | EAppOp { op; args = [arg1] } -> Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1 - | EAppOp { op = HandleDefaultOpt, _; args = _ } -> - failwith "should not happen because of keep_special_ops" + (* | EAppOp { op = HandleDefaultOpt, _; args = _ } -> + * failwith "should not happen because of keep_special_ops" *) | EApp { f; args } -> Format.fprintf fmt "%a(@[%a)@]" (format_expression ctx) f (Format.pp_print_list diff --git a/compiler/scalc/to_python.ml b/compiler/scalc/to_python.ml index d3b7a79d..6fb80eb8 100644 --- a/compiler/scalc/to_python.ml +++ b/compiler/scalc/to_python.ml @@ -88,7 +88,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit = | Reduce -> Format.pp_print_string fmt "list_reduce" | Filter -> Format.pp_print_string fmt "list_filter" | Fold -> Format.pp_print_string fmt "list_fold_left" - | HandleDefaultOpt -> Format.pp_print_string fmt "handle_default_opt" + | HandleExceptions -> Format.pp_print_string fmt "handle_exceptions" | FromClosureEnv | ToClosureEnv -> failwith "unimplemented" let format_uid_list (fmt : Format.formatter) (uids : Uid.MarkedString.info list) @@ -348,27 +348,25 @@ let rec format_expression ctx (fmt : Format.formatter) (e : expr) : unit = Format.fprintf fmt "%a %a" format_op op (format_expression ctx) arg1 | EAppOp { op; args = [arg1] } -> Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1 - | EAppOp { op = (HandleDefaultOpt, _) as op; args } -> - let pos = Mark.get e in - Format.fprintf fmt - "%a(@[SourcePosition(filename=\"%s\",@ start_line=%d,@ \ - start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]" - format_op op (Pos.get_file pos) (Pos.get_start_line pos) - (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) - format_string_list (Pos.get_law_info pos) + (* | EAppOp { op = ((HandleDefaultOpt), _) as op; args } -> + * let pos = Mark.get e in + * Format.fprintf fmt + * "%a(@[SourcePosition(filename=\"%s\",@ start_line=%d,@ \ + * start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]" + * format_op op (Pos.get_file pos) (Pos.get_start_line pos) + * (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) + * format_string_list (Pos.get_law_info pos) + * (Format.pp_print_list + * ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") + * (format_expression ctx)) + * args *) + | EApp { f = EFunc x, _; args = [(EArray el, _)] as args } + when Ast.FuncName.compare x Ast.handle_exceptions = 0 -> + Format.fprintf fmt "%a([%a], %a)@]" format_func_name x (Format.pp_print_list - ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") - (format_expression ctx)) - args - | EApp { f = EFunc x, pos; args } - when Ast.FuncName.compare x Ast.handle_default = 0 - || Ast.FuncName.compare x Ast.handle_default_opt = 0 -> - Format.fprintf fmt - "%a(@[SourcePosition(filename=\"%s\",@ start_line=%d,@ \ - start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]" - format_func_name x (Pos.get_file pos) (Pos.get_start_line pos) - (Pos.get_start_column pos) (Pos.get_end_line pos) (Pos.get_end_column pos) - format_string_list (Pos.get_law_info pos) + ~pp_sep:(fun ppf () -> Format.fprintf ppf ",@ ") + format_position) + (List.map Mark.get el) (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ",@ ") (format_expression ctx)) diff --git a/compiler/scalc/to_r.ml b/compiler/scalc/to_r.ml index b8352b7a..ab3aa373 100644 --- a/compiler/scalc/to_r.ml +++ b/compiler/scalc/to_r.ml @@ -103,7 +103,7 @@ let format_op (fmt : Format.formatter) (op : operator Mark.pos) : unit = | Reduce -> Format.pp_print_string fmt "catala_list_reduce" | Filter -> Format.pp_print_string fmt "catala_list_filter" | Fold -> Format.pp_print_string fmt "catala_list_fold_left" - | HandleDefaultOpt | FromClosureEnv | ToClosureEnv -> failwith "unimplemented" + | HandleExceptions | FromClosureEnv | ToClosureEnv -> failwith "unimplemented" let format_string_list (fmt : Format.formatter) (uids : string list) : unit = let sanitize_quotes = Re.compile (Re.char '"') in @@ -320,7 +320,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : Format.fprintf fmt "%a %a" format_op op (format_expression ctx) arg1 | EAppOp { op; args = [arg1] } -> Format.fprintf fmt "%a(%a)" format_op op (format_expression ctx) arg1 - | EAppOp { op = HandleDefaultOpt, _; _ } -> + | EAppOp { op = HandleExceptions, _; _ } -> Message.error ~internal:true "R compilation does not currently support the avoiding of exceptions" (* TODO: port the following to avoid-exceptions @@ -337,8 +337,7 @@ let rec format_expression (ctx : decl_ctx) (fmt : Format.formatter) (e : expr) : * (format_expression ctx)) * args *) | EApp { f = EFunc x, pos; args } - when Ast.FuncName.compare x Ast.handle_default = 0 - || Ast.FuncName.compare x Ast.handle_default_opt = 0 -> + when Ast.FuncName.compare x Ast.handle_exceptions = 0 -> Format.fprintf fmt "%a(@[catala_position(filename=\"%s\",@ start_line=%d,@ \ start_column=%d,@ end_line=%d, end_column=%d,@ law_headings=%a), %a)@]" diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 5af7e43a..0d2e6bf9 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -372,7 +372,7 @@ module Op = struct (* * polymorphic *) | Reduce : < polymorphic ; .. > t | Fold : < polymorphic ; .. > t - | HandleDefaultOpt : < polymorphic ; .. > t + | HandleExceptions : < polymorphic ; .. > t end type 'a operator = 'a Op.t diff --git a/compiler/shared_ast/interpreter.ml b/compiler/shared_ast/interpreter.ml index bd4a6341..d0f62c6b 100644 --- a/compiler/shared_ast/interpreter.ml +++ b/compiler/shared_ast/interpreter.ml @@ -422,7 +422,7 @@ let rec evaluate_operator ELit (LBool (o_eq_dat_dat x y)) | Eq_dur_dur, [(ELit (LDuration x), _); (ELit (LDuration y), _)] -> ELit (LBool (o_eq_dur_dur (rpos ()) x y)) - | HandleDefaultOpt, [(EArray exps, _); justification; conclusion] -> ( + | HandleExceptions, [(EArray exps, _)] -> ( let valid_exceptions = ListLabels.filter exps ~f:(function | EInj { name; cons; _ }, _ when EnumName.equal name Expr.option_enum -> @@ -430,28 +430,9 @@ let rec evaluate_operator | _ -> err ()) in match valid_exceptions with - | [] -> ( - let e = evaluate_expr (Expr.unthunk_term_nobox justification) in - match Mark.remove e with - | ELit (LBool true) -> - Mark.remove (evaluate_expr (Expr.unthunk_term_nobox conclusion)) - | ELit (LBool false) -> - EInj - { - name = Expr.option_enum; - cons = Expr.none_constr; - e = Mark.copy justification (ELit LUnit); - } - | EInj { name; cons; e } - when EnumName.equal name Expr.option_enum - && EnumConstructor.equal cons Expr.none_constr -> - EInj - { - name = Expr.option_enum; - cons = Expr.none_constr; - e = Mark.copy e (ELit LUnit); - } - | _ -> err ()) + | [] -> + EInj + { name = Expr.option_enum; cons = Expr.none_constr; e = ELit LUnit, m } | [((EInj { cons; name; _ } as e), _)] when EnumName.equal name Expr.option_enum && EnumConstructor.equal cons Expr.some_constr -> @@ -472,7 +453,7 @@ let rec evaluate_operator | Lte_mon_mon | Lte_dat_dat | Lte_dur_dur | Gt_int_int | Gt_rat_rat | Gt_mon_mon | Gt_dat_dat | Gt_dur_dur | Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dat_dat | Gte_dur_dur | Eq_int_int | Eq_rat_rat - | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur | HandleDefaultOpt ), + | Eq_mon_mon | Eq_dat_dat | Eq_dur_dur | HandleExceptions ), _ ) -> err () diff --git a/compiler/shared_ast/operator.ml b/compiler/shared_ast/operator.ml index 0ce7d4fa..d7bbb8e4 100644 --- a/compiler/shared_ast/operator.ml +++ b/compiler/shared_ast/operator.ml @@ -108,7 +108,7 @@ let name : type a. a t -> string = function | Eq_dur_dur -> "o_eq_dur_dur" | Eq_dat_dat -> "o_eq_dat_dat" | Fold -> "o_fold" - | HandleDefaultOpt -> "o_handledefaultopt" + | HandleExceptions -> "handle_exceptions" | ToClosureEnv -> "o_toclosureenv" | FromClosureEnv -> "o_fromclosureenv" @@ -231,7 +231,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) = | Eq_dat_dat, Eq_dat_dat | Eq_dur_dur, Eq_dur_dur | Fold, Fold - | HandleDefaultOpt, HandleDefaultOpt + | HandleExceptions, HandleExceptions | FromClosureEnv, FromClosureEnv | ToClosureEnv, ToClosureEnv -> 0 | Not, _ -> -1 | _, Not -> 1 | Length, _ -> -1 | _, Length -> 1 @@ -316,7 +316,7 @@ let compare (type a1 a2) (t1 : a1 t) (t2 : a2 t) = | Eq_mon_mon, _ -> -1 | _, Eq_mon_mon -> 1 | Eq_dat_dat, _ -> -1 | _, Eq_dat_dat -> 1 | Eq_dur_dur, _ -> -1 | _, Eq_dur_dur -> 1 - | HandleDefaultOpt, _ -> -1 | _, HandleDefaultOpt -> 1 + | HandleExceptions, _ -> -1 | _, HandleExceptions -> 1 | FromClosureEnv, _ -> -1 | _, FromClosureEnv -> 1 | ToClosureEnv, _ -> -1 | _, ToClosureEnv -> 1 | Fold, _ | _, Fold -> . @@ -341,7 +341,7 @@ let kind_dispatch : _ ) as op -> monomorphic op | ( ( Log _ | Length | Eq | Map | Map2 | Concat | Filter | Reduce | Fold - | HandleDefaultOpt | FromClosureEnv | ToClosureEnv ), + | HandleExceptions | FromClosureEnv | ToClosureEnv ), _ ) as op -> polymorphic op | ( ( Minus | ToRat | ToMoney | Round | Add | Sub | Mult | Div | Lt | Lte | Gt @@ -374,7 +374,7 @@ type 'a no_overloads = let translate (t : 'a no_overloads t Mark.pos) : 'b no_overloads t Mark.pos = match t with | ( ( Not | GetDay | GetMonth | GetYear | FirstDayOfMonth | LastDayOfMonth - | And | Or | Xor | HandleDefaultOpt | Log _ | Length | Eq | Map | Map2 + | And | Or | Xor | HandleExceptions | Log _ | Length | Eq | Map | Map2 | Concat | Filter | Reduce | Fold | Minus_int | Minus_rat | Minus_mon | Minus_dur | ToRat_int | ToRat_mon | ToMoney_rat | Round_rat | Round_mon | Add_int_int | Add_rat_rat | Add_mon_mon | Add_dat_dur _ | Add_dur_dur diff --git a/compiler/shared_ast/print.ml b/compiler/shared_ast/print.ml index 5b369006..3141b79d 100644 --- a/compiler/shared_ast/print.ml +++ b/compiler/shared_ast/print.ml @@ -280,7 +280,7 @@ let operator_to_string : type a. a Op.t -> string = | Eq_dur_dur -> "=^" | Eq_dat_dat -> "=@" | Fold -> "fold" - | HandleDefaultOpt -> "handle_default_opt" + | HandleExceptions -> "handle_exceptions" | ToClosureEnv -> "to_closure_env" | FromClosureEnv -> "from_closure_env" @@ -324,7 +324,7 @@ let operator_to_shorter_string : type a. a Op.t -> string = | Gte_int_int | Gte_rat_rat | Gte_mon_mon | Gte_dur_dur | Gte_dat_dat | Gte -> ">=" | Fold -> "fold" - | HandleDefaultOpt -> "handle_default_opt" + | HandleExceptions -> "handle_exceptions" | ToClosureEnv -> "to_closure_env" | FromClosureEnv -> "from_closure_env" @@ -400,7 +400,7 @@ module Precedence = struct | Div | Div_int_int | Div_rat_rat | Div_mon_rat | Div_mon_mon | Div_dur_dur -> Op Div - | HandleDefaultOpt | Map | Map2 | Concat | Filter | Reduce | Fold + | HandleExceptions | Map | Map2 | Concat | Filter | Reduce | Fold | ToClosureEnv | FromClosureEnv -> App) | EApp _ -> App @@ -865,13 +865,12 @@ let enum fmt (pp_name : Format.formatter -> unit) (c : typ EnumConstructor.Map.t) = - Format.fprintf fmt "@[%a %t %a@ %a@]" keyword "type" pp_name punctuation - "=" - (EnumConstructor.Map.format_bindings - ~pp_sep:(fun _ _ -> ()) + Format.fprintf fmt "@[%a %t %a@ %a@]@," keyword "type" pp_name + punctuation "=" + (EnumConstructor.Map.format_bindings ~pp_sep:Format.pp_print_space (fun fmt pp_n ty -> - Format.fprintf fmt "@[ %a %t %a %a@]@," punctuation "|" pp_n - keyword "of" + Format.fprintf fmt "@[%a %t %a %a@]" punctuation "|" pp_n keyword + "of" (if debug then typ_debug else typ decl_ctx) ty)) c diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 554a4039..8302822b 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -294,7 +294,6 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) : let any2 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in let any3 = lazy (UnionFind.make (TAny (Any.fresh ()), pos)) in let bt = lazy (UnionFind.make (TLit TBool, pos)) in - let ut = lazy (UnionFind.make (TLit TUnit, pos)) in let it = lazy (UnionFind.make (TLit TInt, pos)) in let cet = lazy (UnionFind.make (TClosureEnv, pos)) in let array a = lazy (UnionFind.make (TArray (Lazy.force a), pos)) in @@ -314,8 +313,7 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) : | Log (PosRecordIfTrueBool, _) -> [bt] @-> bt | Log _ -> [any] @-> any | Length -> [array any] @-> it - | HandleDefaultOpt -> - [array (option any); [ut] @-> bt; [ut] @-> option any] @-> option any + | HandleExceptions -> [array (option any)] @-> option any | ToClosureEnv -> [any] @-> cet | FromClosureEnv -> [cet] @-> any in @@ -347,7 +345,10 @@ let polymorphic_op_return_type | Log (PosRecordIfTrueBool, _), _ -> uf (TLit TBool) | Log _, [tau] -> tau | Length, _ -> uf (TLit TInt) - | HandleDefaultOpt, [_; _; tf] -> return_type tf 1 + | HandleExceptions, [tau] -> + let t_inner = any () in + unify ctx e tau (uf (TArray t_inner)); + t_inner | ToClosureEnv, _ -> uf TClosureEnv | FromClosureEnv, _ -> any () | _ -> Message.error ~pos "Mismatched operator arguments" diff --git a/runtimes/ocaml/runtime.ml b/runtimes/ocaml/runtime.ml index b8b8a5ea..011d5c04 100644 --- a/runtimes/ocaml/runtime.ml +++ b/runtimes/ocaml/runtime.ml @@ -716,11 +716,9 @@ module EventParser = struct ctx.events end -let handle_default_opt +let handle_exceptions (pos : source_position array) - (exceptions : 'a Eoption.t array) - (just : unit -> bool) - (cons : unit -> 'a Eoption.t) : 'a Eoption.t = + (exceptions : 'a Eoption.t array) : 'a Eoption.t = let len = Array.length exceptions in let rec filt_except i = if i < len then @@ -730,7 +728,7 @@ let handle_default_opt else [] in match filt_except 0 with - | [] -> if just () then cons () else Eoption.ENone () + | [] -> Eoption.ENone () | [(res, _)] -> res | res -> error Conflict (List.map (fun (_, i) -> pos.(i)) res) diff --git a/runtimes/ocaml/runtime.mli b/runtimes/ocaml/runtime.mli index abdb7d42..6301db54 100644 --- a/runtimes/ocaml/runtime.mli +++ b/runtimes/ocaml/runtime.mli @@ -335,12 +335,8 @@ val duration_to_string : duration -> string (**{1 Defaults} *) -val handle_default_opt : - source_position array -> - 'a Eoption.t array -> - (unit -> bool) -> - (unit -> 'a Eoption.t) -> - 'a Eoption.t +val handle_exceptions : + source_position array -> 'a Eoption.t array -> 'a Eoption.t (** @raise Error Conflict *) (**{1 Operators} *) diff --git a/runtimes/python/src/catala/runtime.py b/runtimes/python/src/catala/runtime.py index 30401772..d24ea08b 100644 --- a/runtimes/python/src/catala/runtime.py +++ b/runtimes/python/src/catala/runtime.py @@ -383,9 +383,9 @@ class NoValue(CatalaError): source_position) class Conflict(CatalaError): - def __init__(self, source_position: SourcePosition) -> None: - super().__init__("two or more concurring valid computations", - source_position) + def __init__(self, pos1: SourcePosition, pos2: SourcePosition) -> None: + super().__init__("two or more concurring valid computations:\nAt {}".format(pos2), + pos1) class DivisionByZero(CatalaError): def __init__(self, source_position: SourcePosition) -> None: @@ -606,28 +606,21 @@ def list_length(l: List[Alpha]) -> Integer: # ======== -def handle_default_opt( - pos: SourcePosition, - exceptions: List[Optional[Any]], - just: Callable[[Unit], bool], - cons: Callable[[Unit], Optional[Alpha]] -) -> Optional[Alpha]: +def handle_exceptions( + pos: List[SourcePosition], + exceptions: List[Optional[Alpha]]) +-> Optional[Alpha]: acc: Optional[Alpha] = None - for exception in exceptions: - if acc is None: - acc = exception - elif not (acc is None) and exception is None: + acc_pos: Optional[pos] = None + for exception, pos in zip(exceptions, pos): + if exception is None: pass # acc stays the same - elif not (acc is None) and not (exception is None): - raise Conflict(pos) - if acc is None: - b = just(Unit()) - if b: - return cons(Unit()) - else: - return None - else: - return acc + elif acc is None: + acc = exception + acc_pos = pos + else + raise Conflict(acc_pos,pos) + return acc def no_input() -> Callable[[Unit], Alpha]: