diff --git a/src/catala/surface/desugaring.ml b/src/catala/surface/desugaring.ml index dec57791..cc232e05 100644 --- a/src/catala/surface/desugaring.ml +++ b/src/catala/surface/desugaring.ml @@ -728,70 +728,106 @@ and disambiguate_match_and_build_expression (scope : Scopelang.Ast.ScopeName.t) (ctxt : Name_resolution.context) (cases : Ast.match_case Pos.marked list) : Scopelang.Ast.expr Pos.marked Bindlib.box Scopelang.Ast.EnumConstructorMap.t * Scopelang.Ast.EnumName.t = - let prev_e_uid = ref None in + let manage_match_cases (cases_d, e_uid) (case, _pos_case) = + match case with + | Ast.MatchCase case -> + let constructor, binding = Pos.unmark case.Ast.match_case_pattern in + let e_uid', c_uid = + disambiguate_constructor ctxt constructor (Pos.get_position case.Ast.match_case_pattern) + in + let e_uid = + match e_uid with + | None -> e_uid' + | Some e_uid -> + if e_uid = e_uid' then e_uid + else + Errors.raise_spanned_error + (Format.asprintf + "This case matches a constructor of enumeration %a but previous case were \ + matching constructors of enumeration %a" + Scopelang.Ast.EnumName.format_t e_uid Scopelang.Ast.EnumName.format_t e_uid') + (Pos.get_position case.Ast.match_case_pattern) + in + (match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with + | None -> () + | Some e_case -> + Errors.raise_multispanned_error + (Format.asprintf "The constructor %a has been matched twice:" + Scopelang.Ast.EnumConstructor.format_t c_uid) + [ + (None, Pos.get_position case.match_case_expr); + (None, Pos.get_position (Bindlib.unbox e_case)); + ]); + let ctxt, (param_var, param_pos) = + match binding with + | None -> (ctxt, (Scopelang.Ast.Var.make ("_", Pos.no_pos), Pos.no_pos)) + | Some param -> + let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in + (ctxt, (param_var, Pos.get_position param)) + in + let case_body = translate_expr scope ctxt case.Ast.match_case_expr in + let e_binder = Bindlib.bind_mvar (Array.of_list [ param_var ]) case_body in + let case_expr = + Bindlib.box_apply2 + (fun e_binder case_body -> + Pos.same_pos_as + (Scopelang.Ast.EAbs + ( (e_binder, param_pos), + [ + Scopelang.Ast.EnumConstructorMap.find c_uid + (Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums); + ] )) + case_body) + e_binder case_body + in + (Scopelang.Ast.EnumConstructorMap.add c_uid case_expr cases_d, Some e_uid) + | Ast.WildCard match_case_expr -> ( + match e_uid with + | None -> Errors.raise_error "Should not be the first case." + | Some e_uid -> + (* Gets all constructors of [e_uid]. *) + let constructors_map = Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums in + let missing_constructors = + Scopelang.Ast.EnumConstructorMap.filter_map + (fun c_uid _ -> + match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with + | Some _ -> None + | None -> Some c_uid) + constructors_map + in + + if Scopelang.Ast.EnumConstructorMap.is_empty missing_constructors then + failwith "Un reachable match case, all constructors are described." + else + (* Creates the [wildcard_payload] *) + let param = ("wildcard_payload", Pos.no_pos) in + let ctxt, (param_var, param_pos) = + let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in + (ctxt, (param_var, Pos.get_position param)) + in + let case_body = translate_expr scope ctxt match_case_expr in + let e_binder = Bindlib.bind_mvar (Array.of_list [ param_var ]) case_body in + let bind_wildcard_payload c_uid _ (cases_d, e_uid_opt) = + let case_expr = + Bindlib.box_apply2 + (fun e_binder case_body -> + Pos.same_pos_as + (Scopelang.Ast.EAbs + ( (e_binder, param_pos), + [ + Scopelang.Ast.EnumConstructorMap.find c_uid + (Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums); + ] )) + case_body) + e_binder case_body + in + (Scopelang.Ast.EnumConstructorMap.add c_uid case_expr cases_d, e_uid_opt) + in + Scopelang.Ast.EnumConstructorMap.fold bind_wildcard_payload missing_constructors + (cases_d, Some e_uid)) + in let expr, e_name = - List.fold_left - (fun (cases_d, e_uid) (case, _pos_case) -> - match case with - | Ast.MatchCase case -> - let constructor, binding = Pos.unmark case.Ast.match_case_pattern in - let e_uid', c_uid = - disambiguate_constructor ctxt constructor - (Pos.get_position case.Ast.match_case_pattern) - in - let e_uid = - match e_uid with - | None -> e_uid' - | Some e_uid -> - if e_uid = e_uid' then e_uid - else - Errors.raise_spanned_error - (Format.asprintf - "This case matches a constructor of enumeration %a but previous case were \ - matching constructors of enumeration %a" - Scopelang.Ast.EnumName.format_t e_uid Scopelang.Ast.EnumName.format_t - e_uid') - (Pos.get_position case.Ast.match_case_pattern) - in - (match Scopelang.Ast.EnumConstructorMap.find_opt c_uid cases_d with - | None -> () - | Some e_case -> - Errors.raise_multispanned_error - (Format.asprintf "The constructor %a has been matched twice:" - Scopelang.Ast.EnumConstructor.format_t c_uid) - [ - (None, Pos.get_position case.match_case_expr); - (None, Pos.get_position (Bindlib.unbox e_case)); - ]); - let ctxt, (param_var, param_pos) = - match binding with - | None -> (ctxt, (Scopelang.Ast.Var.make ("_", Pos.no_pos), Pos.no_pos)) - | Some param -> - let ctxt, param_var = Name_resolution.add_def_local_var ctxt param in - (ctxt, (param_var, Pos.get_position param)) - in - let case_body = translate_expr scope ctxt case.Ast.match_case_expr in - let e_binder = Bindlib.bind_mvar (Array.of_list [ param_var ]) case_body in - let case_expr = - Bindlib.box_apply2 - (fun e_binder case_body -> - Pos.same_pos_as - (Scopelang.Ast.EAbs - ( (e_binder, param_pos), - [ - Scopelang.Ast.EnumConstructorMap.find c_uid - (Scopelang.Ast.EnumMap.find e_uid ctxt.Name_resolution.enums); - ] )) - case_body) - e_binder case_body - in - prev_e_uid := Some e_uid; - (Scopelang.Ast.EnumConstructorMap.add c_uid case_expr cases_d, Some e_uid) - | Ast.WildCard _ -> - if Option.is_none !prev_e_uid then Errors.raise_error "Should not be the first case." - else failwith "TODO: Manage wildcard.") - (Scopelang.Ast.EnumConstructorMap.empty, None) - cases + List.fold_left manage_match_cases (Scopelang.Ast.EnumConstructorMap.empty, None) cases in (expr, Option.get e_name)