Fix monomorphization problems with [TAny] left

This commit is contained in:
Denis Merigoux 2024-01-17 16:03:20 +01:00
parent 0a8fdde7de
commit 5310e47e5b
No known key found for this signature in database
GPG Key ID: EE99DCFA365C3EE3
10 changed files with 100 additions and 39 deletions

View File

@ -140,6 +140,12 @@ module Content = struct
let add_suggestion (content : t) (suggestion : string list) = let add_suggestion (content : t) (suggestion : string list) =
content @ [Suggestion suggestion] content @ [Suggestion suggestion]
let add_position
(content : t)
?(message : message option = None)
(position : Pos.t) =
content @ [Position { pos = position; pos_message = message }]
let of_string (s : string) : t = let of_string (s : string) : t =
[MainMessage (fun ppf -> Format.pp_print_string ppf s)] [MainMessage (fun ppf -> Format.pp_print_string ppf s)]

View File

@ -50,6 +50,7 @@ module Content : sig
val to_internal_error : t -> t val to_internal_error : t -> t
val add_suggestion : t -> string list -> t val add_suggestion : t -> string list -> t
val add_position : t -> ?message:message option -> Pos.t -> t
(** {2 Content emission}*) (** {2 Content emission}*)

View File

@ -23,7 +23,7 @@ let expr ctx env e =
[Some] *) [Some] *)
(* Intermediate unboxings are fine since the [check_expr] will rebox in (* Intermediate unboxings are fine since the [check_expr] will rebox in
depth *) depth *)
Typing.check_expr ~leave_unresolved:false ctx ~env (Expr.unbox e) Typing.check_expr ~leave_unresolved:ErrorOnAny ctx ~env (Expr.unbox e)
let rule ctx env rule = let rule ctx env rule =
let env = let env =

View File

@ -192,7 +192,7 @@ module Passes = struct
match typed with match typed with
| Typed _ -> ( | Typed _ -> (
Message.emit_debug "Typechecking again..."; Message.emit_debug "Typechecking again...";
try Typing.program ~leave_unresolved:false prg try Typing.program ~leave_unresolved:ErrorOnAny prg
with Message.CompilerError error_content -> with Message.CompilerError error_content ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
Printexc.raise_with_backtrace Printexc.raise_with_backtrace
@ -239,14 +239,14 @@ module Passes = struct
| true, _, Untyped _ -> | true, _, Untyped _ ->
Program.untype Program.untype
(Lcalc.From_dcalc.translate_program_without_exceptions (Lcalc.From_dcalc.translate_program_without_exceptions
(Shared_ast.Typing.program ~leave_unresolved:false prg)) (Shared_ast.Typing.program ~leave_unresolved:ErrorOnAny prg))
| true, _, Typed _ -> | true, _, Typed _ ->
Lcalc.From_dcalc.translate_program_without_exceptions prg Lcalc.From_dcalc.translate_program_without_exceptions prg
| false, _, Typed _ -> | false, _, Typed _ ->
Program.untype (Lcalc.From_dcalc.translate_program_with_exceptions prg) Program.untype (Lcalc.From_dcalc.translate_program_with_exceptions prg)
| false, _, Untyped _ -> | false, _, Untyped _ ->
Lcalc.From_dcalc.translate_program_with_exceptions Lcalc.From_dcalc.translate_program_with_exceptions
(Shared_ast.Typing.program ~leave_unresolved:false prg) (Shared_ast.Typing.program ~leave_unresolved:ErrorOnAny prg)
| _, _, Custom _ -> invalid_arg "Driver.Passes.lcalc" | _, _, Custom _ -> invalid_arg "Driver.Passes.lcalc"
in in
let prg = let prg =
@ -271,12 +271,12 @@ module Passes = struct
prg) prg)
in in
Message.emit_debug "Retyping lambda calculus..."; Message.emit_debug "Retyping lambda calculus...";
let prg = Typing.program ~leave_unresolved:true prg in let prg = Typing.program ~leave_unresolved:LeaveAny prg in
let prg, type_ordering = let prg, type_ordering =
if monomorphize_types then ( if monomorphize_types then (
Message.emit_debug "Monomorphizing types..."; Message.emit_debug "Monomorphizing types...";
let prg, type_ordering = Lcalc.Monomorphize.program prg in let prg, type_ordering = Lcalc.Monomorphize.program prg in
let prg = Typing.program ~leave_unresolved:false prg in let prg = Typing.program ~leave_unresolved:ErrorOnAny prg in
prg, type_ordering) prg, type_ordering)
else prg, type_ordering else prg, type_ordering
in in
@ -299,7 +299,7 @@ module Passes = struct
~avoid_exceptions ~closure_conversion ~monomorphize_types ~avoid_exceptions ~closure_conversion ~monomorphize_types
in in
Message.emit_debug "Retyping lambda calculus..."; Message.emit_debug "Retyping lambda calculus...";
let prg = Typing.program ~leave_unresolved:true prg in let prg = Typing.program ~leave_unresolved:LeaveAny prg in
debug_pass_name "scalc"; debug_pass_name "scalc";
( Scalc.From_lcalc.translate_program ( Scalc.From_lcalc.translate_program
~config:{ keep_special_ops; dead_value_assignment; no_struct_literals } ~config:{ keep_special_ops; dead_value_assignment; no_struct_literals }
@ -559,7 +559,8 @@ module Commands = struct
(* Additionally, we might want to check the invariants. *) (* Additionally, we might want to check the invariants. *)
if check_invariants then ( if check_invariants then (
let prg = let prg =
Shared_ast.Typing.program ~leave_unresolved:false (Program.untype prg) Shared_ast.Typing.program ~leave_unresolved:ErrorOnAny
(Program.untype prg)
in in
Message.emit_debug "Checking invariants..."; Message.emit_debug "Checking invariants...";
let result = Dcalc.Invariants.check_all_invariants prg in let result = Dcalc.Invariants.check_all_invariants prg in

View File

@ -51,16 +51,24 @@ let rec translate_typ (tau : typ) : typ =
end end
let rec translate_default let rec translate_default
(exceptions : 'm D.expr list) (exceptions : typed D.expr list)
(just : 'm D.expr) (just : typed D.expr)
(cons : 'm D.expr) (cons : typed D.expr)
(mark_default : 'm mark) : 'm A.expr boxed = (mark_default : typed mark) : typed A.expr boxed =
(* Since the program is well typed, all exceptions have as type [option 't] *) (* Since the program is well typed, all exceptions have as type [option 't] *)
let exceptions = List.map translate_expr exceptions in let exceptions = List.map translate_expr exceptions in
let pos = Expr.mark_pos mark_default in let pos = Expr.mark_pos mark_default in
let exceptions_and_cons_ty =
match mark_default with Typed { ty; _ } -> translate_typ ty
in
let exceptions = let exceptions =
Expr.eappop ~op:Op.HandleDefaultOpt Expr.eappop ~op:Op.HandleDefaultOpt
~tys:[TAny, pos; TAny, pos; TAny, pos] ~tys:
[
TArray exceptions_and_cons_ty, pos;
TArrow ([TLit TUnit, pos], (TLit TBool, pos)), pos;
TArrow ([TLit TUnit, pos], exceptions_and_cons_ty), pos;
]
~args: ~args:
[ [
Expr.earray exceptions mark_default; Expr.earray exceptions mark_default;
@ -75,7 +83,7 @@ let rec translate_default
in in
exceptions exceptions
and translate_expr (e : 'm D.expr) : 'm A.expr boxed = and translate_expr (e : typed D.expr) : typed A.expr boxed =
let mark = Mark.get e in let mark = Mark.get e in
match Mark.remove e with match Mark.remove e with
| EEmptyError -> | EEmptyError ->
@ -120,8 +128,9 @@ and translate_expr (e : 'm D.expr) : 'm A.expr boxed =
Expr.map ~f:translate_expr (Mark.add mark e) Expr.map ~f:translate_expr (Mark.add mark e)
| _ -> . | _ -> .
let translate_scope_body_expr (scope_body_expr : 'expr1 scope_body_expr) : let translate_scope_body_expr
'expr2 scope_body_expr Bindlib.box = (scope_body_expr : (dcalc, typed) gexpr scope_body_expr) :
(lcalc, typed) gexpr scope_body_expr Bindlib.box =
Scope.fold_right_lets Scope.fold_right_lets
~f:(fun scope_let var_next acc -> ~f:(fun scope_let var_next acc ->
Bindlib.box_apply2 Bindlib.box_apply2

View File

@ -50,8 +50,7 @@ let collect_monomorphized_instances (prg : typed program) :
let tuple_instances_counter = ref 0 in let tuple_instances_counter = ref 0 in
let rec collect_typ acc typ = let rec collect_typ acc typ =
match Mark.remove typ with match Mark.remove typ with
| TStruct _ | TEnum _ | TAny | TClosureEnv | TLit _ -> acc | TTuple args when List.for_all (fun t -> Mark.remove t <> TAny) args ->
| TTuple args ->
let new_acc = let new_acc =
{ {
acc with acc with
@ -83,7 +82,7 @@ let collect_monomorphized_instances (prg : typed program) :
| TArray t | TDefault t -> collect_typ acc t | TArray t | TDefault t -> collect_typ acc t
| TArrow (args, ret) -> | TArrow (args, ret) ->
List.fold_left collect_typ (collect_typ acc ret) args List.fold_left collect_typ (collect_typ acc ret) args
| TOption t -> | TOption t when Mark.remove t <> TAny ->
let new_acc = let new_acc =
{ {
acc with acc with
@ -114,10 +113,34 @@ let collect_monomorphized_instances (prg : typed program) :
} }
in in
collect_typ new_acc t collect_typ new_acc t
| TStruct _ | TEnum _ | TAny | TClosureEnv | TLit _ -> acc
| TOption _ | TTuple _ ->
raise
(Message.CompilerError
(Message.Content.add_position
(Message.Content.to_internal_error
(Message.Content.of_message (fun fmt ->
Format.fprintf fmt
"Some types in tuples or option have not been resolved \
by the typechecking before monomorphization.")))
(Mark.get typ)))
in in
let rec collect_expr acc e = let rec collect_expr acc e =
let acc = collect_typ acc (Expr.maybe_ty (Mark.get e)) in let acc = collect_typ acc (Expr.maybe_ty (Mark.get e)) in
Expr.shallow_fold (fun e acc -> collect_expr acc e) e acc match Mark.remove e with
| EAbs { binder; tys } ->
let acc = List.fold_left collect_typ acc tys in
let _, body = Bindlib.unmbind binder in
collect_expr acc body
| EApp { f; args; tys } ->
let acc = List.fold_left collect_typ acc tys in
let acc = List.fold_left collect_expr acc args in
collect_expr acc f
| EAppOp { op = _; args; tys } ->
let acc = List.fold_left collect_typ acc tys in
let acc = List.fold_left collect_expr acc args in
acc
| _ -> Expr.shallow_fold (fun e acc -> collect_expr acc e) e acc
in in
let acc = let acc =
Scope.fold_left Scope.fold_left
@ -229,7 +252,9 @@ let rec monomorphize_expr
| EInj { name; e = e1; cons } when EnumName.equal name Expr.option_enum -> | EInj { name; e = e1; cons } when EnumName.equal name Expr.option_enum ->
let option_instance = let option_instance =
TypMap.find TypMap.find
(Mark.remove (Expr.maybe_ty (Mark.get e1))) (match Mark.remove (Expr.maybe_ty (Mark.get e)) with
| TOption t -> Mark.remove t
| _ -> failwith "should not happen")
monomorphized_instances.options monomorphized_instances.options
in in
let new_e1 = monomorphize_expr monomorphized_instances e1 in let new_e1 = monomorphize_expr monomorphized_instances e1 in

View File

@ -67,11 +67,15 @@ type 'm program = {
let type_rule decl_ctx env = function let type_rule decl_ctx env = function
| Definition (loc, typ, io, expr) -> | Definition (loc, typ, io, expr) ->
let expr' = Typing.expr ~leave_unresolved:false decl_ctx ~env ~typ expr in let expr' =
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
in
Definition (loc, typ, io, Expr.unbox expr') Definition (loc, typ, io, Expr.unbox expr')
| Assertion expr -> | Assertion expr ->
let typ = Mark.add (Expr.pos expr) (TLit TBool) in let typ = Mark.add (Expr.pos expr) (TLit TBool) in
let expr' = Typing.expr ~leave_unresolved:false decl_ctx ~env ~typ expr in let expr' =
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
in
Assertion (Expr.unbox expr') Assertion (Expr.unbox expr')
| Call (sc_name, ssc_name, m) -> | Call (sc_name, ssc_name, m) ->
let pos = Expr.mark_pos m in let pos = Expr.mark_pos m in
@ -115,7 +119,8 @@ let type_program (type m) (prg : m program) : typed program =
TopdefName.Map.map TopdefName.Map.map
(fun (expr, typ) -> (fun (expr, typ) ->
( Expr.unbox ( Expr.unbox
(Typing.expr prg.program_ctx ~leave_unresolved:false ~env ~typ expr), (Typing.expr prg.program_ctx ~leave_unresolved:ErrorOnAny ~env ~typ
expr),
typ )) typ ))
prg.program_topdefs prg.program_topdefs
in in

View File

@ -880,9 +880,13 @@ let decl_ctx ?(debug = false) decl_ctx (fmt : Format.formatter) (ctx : decl_ctx)
: unit = : unit =
let { ctx_enums; ctx_structs; _ } = ctx in let { ctx_enums; ctx_structs; _ } = ctx in
Format.fprintf fmt "%a@.%a@.@." Format.fprintf fmt "%a@.%a@.@."
(EnumName.Map.format_bindings (enum ~debug decl_ctx)) (EnumName.Map.format_bindings
~pp_sep:(fun fmt () -> Format.fprintf fmt "@.")
(enum ~debug decl_ctx))
ctx_enums ctx_enums
(StructName.Map.format_bindings (struct_ ~debug decl_ctx)) (StructName.Map.format_bindings
~pp_sep:(fun fmt () -> Format.fprintf fmt "@.")
(struct_ ~debug decl_ctx))
ctx_structs ctx_structs
let scope let scope

View File

@ -20,6 +20,8 @@
open Catala_utils open Catala_utils
module A = Definitions module A = Definitions
type resolving_strategy = LeaveAny | ErrorOnAny
module Any = module Any =
Uid.Make Uid.Make
(struct (struct
@ -52,7 +54,8 @@ and naked_typ =
| TAny of Any.t | TAny of Any.t
| TClosureEnv | TClosureEnv
let rec typ_to_ast ~leave_unresolved (ty : unionfind_typ) : A.typ = let rec typ_to_ast ~(leave_unresolved : resolving_strategy) (ty : unionfind_typ)
: A.typ =
let typ_to_ast = typ_to_ast ~leave_unresolved in let typ_to_ast = typ_to_ast ~leave_unresolved in
let ty, pos = UnionFind.get (UnionFind.find ty) in let ty, pos = UnionFind.get (UnionFind.find ty) in
match ty with match ty with
@ -64,14 +67,15 @@ let rec typ_to_ast ~leave_unresolved (ty : unionfind_typ) : A.typ =
| TArrow (t1, t2) -> A.TArrow (List.map typ_to_ast t1, typ_to_ast t2), pos | TArrow (t1, t2) -> A.TArrow (List.map typ_to_ast t1, typ_to_ast t2), pos
| TArray t1 -> A.TArray (typ_to_ast t1), pos | TArray t1 -> A.TArray (typ_to_ast t1), pos
| TDefault t1 -> A.TDefault (typ_to_ast t1), pos | TDefault t1 -> A.TDefault (typ_to_ast t1), pos
| TAny _ -> | TAny _ -> (
if leave_unresolved then A.TAny, pos match leave_unresolved with
else | LeaveAny -> A.TAny, pos
| ErrorOnAny ->
(* No polymorphism in Catala: type inference should return full types (* No polymorphism in Catala: type inference should return full types
without wildcards, and this function is used to recover the types after without wildcards, and this function is used to recover the types after
typing. *) typing. *)
Message.raise_spanned_error pos Message.raise_spanned_error pos
"Internal error: typing at this point could not be resolved" "Internal error: typing at this point could not be resolved")
| TClosureEnv -> TClosureEnv, pos | TClosureEnv -> TClosureEnv, pos
let rec ast_to_typ (ty : A.typ) : unionfind_typ = let rec ast_to_typ (ty : A.typ) : unionfind_typ =
@ -419,7 +423,7 @@ let ty : (_, unionfind_typ A.custom) A.marked -> unionfind_typ =
(** Infers the most permissive type from an expression *) (** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up : let rec typecheck_expr_bottom_up :
type a m. type a m.
leave_unresolved:bool -> leave_unresolved:resolving_strategy ->
A.decl_ctx -> A.decl_ctx ->
(a, m) A.gexpr Env.t -> (a, m) A.gexpr Env.t ->
(a, m) A.gexpr -> (a, m) A.gexpr ->
@ -432,13 +436,15 @@ let rec typecheck_expr_bottom_up :
(** Checks whether the expression can be typed with the provided type *) (** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down : and typecheck_expr_top_down :
type a m. type a m.
leave_unresolved:bool -> leave_unresolved:resolving_strategy ->
A.decl_ctx -> A.decl_ctx ->
(a, m) A.gexpr Env.t -> (a, m) A.gexpr Env.t ->
unionfind_typ -> unionfind_typ ->
(a, m) A.gexpr -> (a, m) A.gexpr ->
(a, unionfind_typ A.custom) A.boxed_gexpr = (a, unionfind_typ A.custom) A.boxed_gexpr =
fun ~leave_unresolved ctx env tau e -> fun ~leave_unresolved ctx env tau e ->
(* Message.emit_debug "Propagating type %a for naked_expr :@.@[<hov 2>%a@]"
(format_typ ctx) tau Expr.format e; *)
let pos_e = Expr.pos e in let pos_e = Expr.pos e in
let () = let () =
(* If there already is a type annotation on the given expr, ensure it (* If there already is a type annotation on the given expr, ensure it
@ -919,7 +925,7 @@ let get_ty_mark ~leave_unresolved (A.Custom { A.custom = uf; pos }) =
let expr_raw let expr_raw
(type a) (type a)
~(leave_unresolved : bool) ~(leave_unresolved : resolving_strategy)
(ctx : A.decl_ctx) (ctx : A.decl_ctx)
?(env = Env.empty ctx) ?(env = Env.empty ctx)
?(typ : A.typ option) ?(typ : A.typ option)

View File

@ -42,11 +42,13 @@ end
(** In the following functions, the [~leave_unresolved] labeled parameter (** In the following functions, the [~leave_unresolved] labeled parameter
controls the behavior of the typer in the case where polymorphic expressions controls the behavior of the typer in the case where polymorphic expressions
are still found after typing: if set to [true], it allows them (giving them are still found after typing: if set to [LeaveAny], it allows them (giving
[TAny] and losing typing information), if set to [false], it aborts. *) them [TAny] and losing typing information); if set to [ErrorOnAny], it
aborts. *)
type resolving_strategy = LeaveAny | ErrorOnAny
val expr : val expr :
leave_unresolved:bool -> leave_unresolved:resolving_strategy ->
decl_ctx -> decl_ctx ->
?env:'e Env.t -> ?env:'e Env.t ->
?typ:typ -> ?typ:typ ->
@ -77,7 +79,7 @@ val expr :
didn't cause problems) *) didn't cause problems) *)
val check_expr : val check_expr :
leave_unresolved:bool -> leave_unresolved:resolving_strategy ->
decl_ctx -> decl_ctx ->
?env:'e Env.t -> ?env:'e Env.t ->
?typ:typ -> ?typ:typ ->
@ -89,7 +91,9 @@ val check_expr :
information, e.g. any [TAny] appearing in the AST is replaced) *) information, e.g. any [TAny] appearing in the AST is replaced) *)
val program : val program :
leave_unresolved:bool -> ('a, 'm) gexpr program -> ('a, typed) gexpr program leave_unresolved:resolving_strategy ->
('a, 'm) gexpr program ->
('a, typed) gexpr program
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the (** Typing on whole programs (as defined in Shared_ast.program, i.e. for the
later dcalc/lcalc stages. later dcalc/lcalc stages.