Typing: add a "assume operator types" mode

This allows for retyping after monomorphisation: a new function just extracts
the return type of the operator, without checking the operand types.

Also to avoid multiplying function arguments around the typer, the flags have
been gathered in a record that is included in the typing environment; it's ok to
give them default values as long as these are the strictest.
This commit is contained in:
Louis Gesbert 2024-02-07 17:41:04 +01:00
parent ae89c870c1
commit a56d95d790
5 changed files with 146 additions and 166 deletions

View File

@ -23,7 +23,7 @@ let expr ctx env e =
[Some] *) [Some] *)
(* Intermediate unboxings are fine since the [check_expr] will rebox in (* Intermediate unboxings are fine since the [check_expr] will rebox in
depth *) depth *)
Typing.check_expr ~leave_unresolved:ErrorOnAny ctx ~env (Expr.unbox e) Typing.check_expr ctx ~env (Expr.unbox e)
let rule ctx env rule = let rule ctx env rule =
let env = let env =

View File

@ -192,7 +192,7 @@ module Passes = struct
match typed with match typed with
| Typed _ -> ( | Typed _ -> (
Message.emit_debug "Typechecking again..."; Message.emit_debug "Typechecking again...";
try Typing.program ~leave_unresolved:ErrorOnAny prg try Typing.program prg
with Message.CompilerError error_content -> with Message.CompilerError error_content ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
Printexc.raise_with_backtrace Printexc.raise_with_backtrace
@ -257,7 +257,7 @@ module Passes = struct
let prg = let prg =
if not closure_conversion then ( if not closure_conversion then (
Message.emit_debug "Retyping lambda calculus..."; Message.emit_debug "Retyping lambda calculus...";
Typing.program ~leave_unresolved:LeaveAny prg) Typing.program ~fail_on_any:false prg)
else ( else (
Message.emit_debug "Performing closure conversion..."; Message.emit_debug "Performing closure conversion...";
let prg = Lcalc.Closure_conversion.closure_conversion prg in let prg = Lcalc.Closure_conversion.closure_conversion prg in
@ -268,16 +268,15 @@ module Passes = struct
else prg else prg
in in
Message.emit_debug "Retyping lambda calculus..."; Message.emit_debug "Retyping lambda calculus...";
Typing.program ~leave_unresolved:LeaveAny prg) Typing.program ~fail_on_any:false prg)
in in
let prg, type_ordering = let prg, type_ordering =
if monomorphize_types then ( if monomorphize_types then (
Message.emit_debug "Monomorphizing types..."; Message.emit_debug "Monomorphizing types...";
Lcalc.Monomorphize.program prg let prg, type_ordering = Lcalc.Monomorphize.program prg in
(* (* FIXME: typing no longer works after monomorphisation, it would Message.emit_debug "Retyping lambda calculus...";
* need special operator handling for arrays and options *) let prg = Typing.program ~fail_on_any:false ~assume_op_types:true prg in
* Message.emit_debug "Retyping lambda calculus..."; prg, type_ordering)
* let prg = Typing.program ~leave_unresolved:LeaveAny prg in *))
else prg, type_ordering else prg, type_ordering
in in
prg, type_ordering prg, type_ordering
@ -556,10 +555,7 @@ module Commands = struct
(* Additionally, we might want to check the invariants. *) (* Additionally, we might want to check the invariants. *)
if check_invariants then ( if check_invariants then (
let prg = let prg = Shared_ast.Typing.program prg in
Shared_ast.Typing.program ~leave_unresolved:ErrorOnAny
(Program.untype prg)
in
Message.emit_debug "Checking invariants..."; Message.emit_debug "Checking invariants...";
if Dcalc.Invariants.check_all_invariants prg then if Dcalc.Invariants.check_all_invariants prg then
Message.emit_result "All invariant checks passed" Message.emit_result "All invariant checks passed"

View File

@ -67,15 +67,11 @@ type 'm program = {
let type_rule decl_ctx env = function let type_rule decl_ctx env = function
| Definition (loc, typ, io, expr) -> | Definition (loc, typ, io, expr) ->
let expr' = let expr' = Typing.expr decl_ctx ~env ~typ expr in
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
in
Definition (loc, typ, io, Expr.unbox expr') Definition (loc, typ, io, Expr.unbox expr')
| Assertion expr -> | Assertion expr ->
let typ = Mark.add (Expr.pos expr) (TLit TBool) in let typ = Mark.add (Expr.pos expr) (TLit TBool) in
let expr' = let expr' = Typing.expr decl_ctx ~env ~typ expr in
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
in
Assertion (Expr.unbox expr') Assertion (Expr.unbox expr')
| Call (sc_name, ssc_name, m) -> | Call (sc_name, ssc_name, m) ->
let pos = Expr.mark_pos m in let pos = Expr.mark_pos m in
@ -118,10 +114,7 @@ let type_program (type m) (prg : m program) : typed program =
let program_topdefs = let program_topdefs =
TopdefName.Map.map TopdefName.Map.map
(fun (expr, typ) -> (fun (expr, typ) ->
( Expr.unbox Expr.unbox (Typing.expr prg.program_ctx ~env ~typ expr), typ)
(Typing.expr prg.program_ctx ~leave_unresolved:ErrorOnAny ~env ~typ
expr),
typ ))
prg.program_topdefs prg.program_topdefs
in in
let program_scopes = let program_scopes =

View File

@ -20,7 +20,7 @@
open Catala_utils open Catala_utils
module A = Definitions module A = Definitions
type resolving_strategy = LeaveAny | ErrorOnAny type flags = { fail_on_any : bool; assume_op_types : bool }
module Any = module Any =
Uid.Make Uid.Make
@ -54,9 +54,8 @@ and naked_typ =
| TAny of Any.t | TAny of Any.t
| TClosureEnv | TClosureEnv
let rec typ_to_ast ~(leave_unresolved : resolving_strategy) (ty : unionfind_typ) let rec typ_to_ast ~(flags : flags) (ty : unionfind_typ) : A.typ =
: A.typ = let typ_to_ast = typ_to_ast ~flags in
let typ_to_ast = typ_to_ast ~leave_unresolved in
let ty, pos = UnionFind.get (UnionFind.find ty) in let ty, pos = UnionFind.get (UnionFind.find ty) in
match ty with match ty with
| TLit l -> A.TLit l, pos | TLit l -> A.TLit l, pos
@ -67,15 +66,14 @@ let rec typ_to_ast ~(leave_unresolved : resolving_strategy) (ty : unionfind_typ)
| TArrow (t1, t2) -> A.TArrow (List.map typ_to_ast t1, typ_to_ast t2), pos | TArrow (t1, t2) -> A.TArrow (List.map typ_to_ast t1, typ_to_ast t2), pos
| TArray t1 -> A.TArray (typ_to_ast t1), pos | TArray t1 -> A.TArray (typ_to_ast t1), pos
| TDefault t1 -> A.TDefault (typ_to_ast t1), pos | TDefault t1 -> A.TDefault (typ_to_ast t1), pos
| TAny _ -> ( | TAny _ ->
match leave_unresolved with if flags.fail_on_any then
| LeaveAny -> A.TAny, pos
| ErrorOnAny ->
(* No polymorphism in Catala: type inference should return full types (* No polymorphism in Catala: type inference should return full types
without wildcards, and this function is used to recover the types after without wildcards, and this function is used to recover the types after
typing. *) typing. *)
Message.raise_spanned_error pos Message.raise_spanned_error pos
"Internal error: typing at this point could not be resolved") "Internal error: typing at this point could not be resolved"
else A.TAny, pos
| TClosureEnv -> TClosureEnv, pos | TClosureEnv -> TClosureEnv, pos
let rec ast_to_typ (ty : A.typ) : unionfind_typ = let rec ast_to_typ (ty : A.typ) : unionfind_typ =
@ -321,8 +319,39 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
in in
Lazy.force ty Lazy.force ty
(* Just returns the return type of the operator, assuming the operand types are
known. Less trict, but useful on monomorphised code where the operators no
longer have their standard types *)
let polymorphic_op_return_type
ctx
e
(op : Operator.polymorphic A.operator Mark.pos)
(targs : unionfind_typ list) : unionfind_typ =
let open Operator in
let pos = Mark.get op in
let uf t = UnionFind.make (t, pos) in
let any _ = uf (TAny (Any.fresh ())) in
let return_type tf arity =
let tret = any () in
unify ctx e tf (UnionFind.make (TArrow (List.init arity any, tret), pos));
tret
in
match Mark.remove op, targs with
| Fold, [_; tau; _] -> tau
| Eq, _ -> uf (TLit TBool)
| Map, [tf; _] -> uf (TArray (return_type tf 1))
| Map2, [tf; _; _] -> uf (TArray (return_type tf 2))
| (Filter | Reduce | Concat), [_; tau] -> tau
| Log (PosRecordIfTrueBool, _), _ -> uf (TLit TBool)
| Log _, [tau] -> tau
| Length, _ -> uf (TLit TInt)
| (HandleDefault | HandleDefaultOpt), [_; _; tf] -> return_type tf 1
| ToClosureEnv, _ -> uf TClosureEnv
| FromClosureEnv, _ -> any ()
| _ -> Message.raise_spanned_error pos "Mismatched operator arguments"
let resolve_overload_ret_type let resolve_overload_ret_type
~leave_unresolved ~flags
(ctx : A.decl_ctx) (ctx : A.decl_ctx)
e e
(op : Operator.overloaded A.operator) (op : Operator.overloaded A.operator)
@ -330,7 +359,7 @@ let resolve_overload_ret_type
let op_ty = let op_ty =
Operator.overload_type ctx Operator.overload_type ctx
(Mark.add (Expr.pos e) op) (Mark.add (Expr.pos e) op)
(List.map (typ_to_ast ~leave_unresolved) tys) (List.map (typ_to_ast ~flags) tys)
in in
ast_to_typ (Type.arrow_return op_ty) ast_to_typ (Type.arrow_return op_ty)
@ -338,6 +367,7 @@ let resolve_overload_ret_type
module Env = struct module Env = struct
type 'e t = { type 'e t = {
flags : flags;
structs : unionfind_typ A.StructField.Map.t A.StructName.Map.t; structs : unionfind_typ A.StructField.Map.t A.StructName.Map.t;
enums : unionfind_typ A.EnumConstructor.Map.t A.EnumName.Map.t; enums : unionfind_typ A.EnumConstructor.Map.t A.EnumName.Map.t;
vars : ('e, unionfind_typ) Var.Map.t; vars : ('e, unionfind_typ) Var.Map.t;
@ -347,10 +377,14 @@ module Env = struct
toplevel_vars : A.typ A.TopdefName.Map.t; toplevel_vars : A.typ A.TopdefName.Map.t;
} }
let empty (decl_ctx : A.decl_ctx) = let empty
?(fail_on_any = true)
?(assume_op_types = false)
(decl_ctx : A.decl_ctx) =
(* We fill the environment initially with the structs and enums (* We fill the environment initially with the structs and enums
declarations *) declarations *)
{ {
flags = { fail_on_any; assume_op_types };
structs = structs =
A.StructName.Map.map A.StructName.Map.map
(fun ty -> A.StructField.Map.map ast_to_typ ty) (fun ty -> A.StructField.Map.map ast_to_typ ty)
@ -423,29 +457,28 @@ let ty : (_, unionfind_typ A.custom) A.marked -> unionfind_typ =
(** Infers the most permissive type from an expression *) (** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up : let rec typecheck_expr_bottom_up :
type a m. type a m.
leave_unresolved:resolving_strategy ->
A.decl_ctx -> A.decl_ctx ->
(a, m) A.gexpr Env.t -> (a, m) A.gexpr Env.t ->
(a, m) A.gexpr -> (a, m) A.gexpr ->
(a, unionfind_typ A.custom) A.boxed_gexpr = (a, unionfind_typ A.custom) A.boxed_gexpr =
fun ~leave_unresolved ctx env e -> fun ctx env e ->
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env
(UnionFind.make (add_pos e (TAny (Any.fresh ())))) (UnionFind.make (add_pos e (TAny (Any.fresh ()))))
e e
(** Checks whether the expression can be typed with the provided type *) (** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down : and typecheck_expr_top_down :
type a m. type a m.
leave_unresolved:resolving_strategy ->
A.decl_ctx -> A.decl_ctx ->
(a, m) A.gexpr Env.t -> (a, m) A.gexpr Env.t ->
unionfind_typ -> unionfind_typ ->
(a, m) A.gexpr -> (a, m) A.gexpr ->
(a, unionfind_typ A.custom) A.boxed_gexpr = (a, unionfind_typ A.custom) A.boxed_gexpr =
fun ~leave_unresolved ctx env tau e -> fun ctx env tau e ->
(* Message.emit_debug "Propagating type %a for naked_expr :@.@[<hov 2>%a@]" (* Message.emit_debug "Propagating type %a for naked_expr :@.@[<hov 2>%a@]"
(format_typ ctx) tau Expr.format e; *) (format_typ ctx) tau Expr.format e; *)
let pos_e = Expr.pos e in let pos_e = Expr.pos e in
let flags = env.flags in
let () = let () =
(* If there already is a type annotation on the given expr, ensure it (* If there already is a type annotation on the given expr, ensure it
matches *) matches *)
@ -519,7 +552,7 @@ and typecheck_expr_top_down :
A.StructField.Map.mapi A.StructField.Map.mapi
(fun f_name f_e -> (fun f_name f_e ->
let f_ty = A.StructField.Map.find f_name str in let f_ty = A.StructField.Map.find f_name str in
typecheck_expr_top_down ~leave_unresolved ctx env f_ty f_e) typecheck_expr_top_down ctx env f_ty f_e)
fields fields
in in
Expr.estruct ~name ~fields mark Expr.estruct ~name ~fields mark
@ -530,8 +563,7 @@ and typecheck_expr_top_down :
| None -> TAny (Any.fresh ()) | None -> TAny (Any.fresh ())
in in
let e_struct' = let e_struct' =
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind t_struct) typecheck_expr_top_down ctx env (unionfind t_struct) e_struct
e_struct
in in
let name = let name =
match UnionFind.get (ty e_struct') with match UnionFind.get (ty e_struct') with
@ -598,8 +630,7 @@ and typecheck_expr_top_down :
in in
let mark = mark_with_tau_and_unify fld_ty in let mark = mark_with_tau_and_unify fld_ty in
let e_struct' = let e_struct' =
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env (unionfind (TStruct name)) e_struct
(unionfind (TStruct name)) e_struct
in in
Expr.estructaccess ~e:e_struct' ~field ~name mark Expr.estructaccess ~e:e_struct' ~field ~name mark
| A.EInj { name; cons; e = e_enum } | A.EInj { name; cons; e = e_enum }
@ -607,23 +638,20 @@ and typecheck_expr_top_down :
if Definitions.EnumConstructor.equal cons Expr.some_constr then if Definitions.EnumConstructor.equal cons Expr.some_constr then
let cell_type = unionfind (TAny (Any.fresh ())) in let cell_type = unionfind (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
let e_enum' = let e_enum' = typecheck_expr_top_down ctx env cell_type e_enum in
typecheck_expr_top_down ~leave_unresolved ctx env cell_type e_enum
in
Expr.einj ~name ~cons ~e:e_enum' mark Expr.einj ~name ~cons ~e:e_enum' mark
else else
(* None constructor *) (* None constructor *)
let cell_type = unionfind (TAny (Any.fresh ())) in let cell_type = unionfind (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
let e_enum' = let e_enum' =
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env (unionfind (TLit TUnit)) e_enum
(unionfind (TLit TUnit)) e_enum
in in
Expr.einj ~name ~cons ~e:e_enum' mark Expr.einj ~name ~cons ~e:e_enum' mark
| A.EInj { name; cons; e = e_enum } -> | A.EInj { name; cons; e = e_enum } ->
let mark = mark_with_tau_and_unify (unionfind (TEnum name)) in let mark = mark_with_tau_and_unify (unionfind (TEnum name)) in
let e_enum' = let e_enum' =
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env
(A.EnumConstructor.Map.find cons (A.EnumName.Map.find name env.enums)) (A.EnumConstructor.Map.find cons (A.EnumName.Map.find name env.enums))
e_enum e_enum
in in
@ -640,14 +668,14 @@ and typecheck_expr_top_down :
in in
let t_ret = unionfind ~pos:e (TAny (Any.fresh ())) in let t_ret = unionfind ~pos:e (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify t_ret in let mark = mark_with_tau_and_unify t_ret in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_arg e1 in let e1' = typecheck_expr_top_down ctx env t_arg e1 in
let cases = let cases =
A.EnumConstructor.Map.merge A.EnumConstructor.Map.merge
(fun _ e e_ty -> (fun _ e e_ty ->
match e, e_ty with match e, e_ty with
| Some e, Some e_ty -> | Some e, Some e_ty ->
Some Some
(typecheck_expr_top_down ~leave_unresolved ctx env (typecheck_expr_top_down ctx env
(unionfind ~pos:e (TArrow ([e_ty], t_ret))) (unionfind ~pos:e (TArrow ([e_ty], t_ret)))
e) e)
| _ -> assert false) | _ -> assert false)
@ -658,10 +686,7 @@ and typecheck_expr_top_down :
let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in
let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify t_ret in let mark = mark_with_tau_and_unify t_ret in
let e1' = let e1' = typecheck_expr_top_down ctx env (unionfind (TEnum name)) e1 in
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TEnum name))
e1
in
let cases = let cases =
A.EnumConstructor.Map.mapi A.EnumConstructor.Map.mapi
(fun c_name e -> (fun c_name e ->
@ -670,7 +695,7 @@ and typecheck_expr_top_down :
there is a change to allow for multiple arguments, it might be there is a change to allow for multiple arguments, it might be
easier to use tuples directly. *) easier to use tuples directly. *)
let e_ty = unionfind ~pos:e (TArrow ([ast_to_typ c_ty], t_ret)) in let e_ty = unionfind ~pos:e (TArrow ([ast_to_typ c_ty], t_ret)) in
typecheck_expr_top_down ~leave_unresolved ctx env e_ty e) typecheck_expr_top_down ctx env e_ty e)
cases cases
in in
Expr.ematch ~e:e1' ~name ~cases mark Expr.ematch ~e:e1' ~name ~cases mark
@ -683,17 +708,15 @@ and typecheck_expr_top_down :
let args' = let args' =
A.ScopeVar.Map.mapi A.ScopeVar.Map.mapi
(fun name -> (fun name ->
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env
(ast_to_typ (A.ScopeVar.Map.find name vars))) (ast_to_typ (A.ScopeVar.Map.find name vars)))
args args
in in
Expr.escopecall ~scope ~args:args' mark Expr.escopecall ~scope ~args:args' mark
| A.ERaise ex -> Expr.eraise ex context_mark | A.ERaise ex -> Expr.eraise ex context_mark
| A.ECatch { body; exn; handler } -> | A.ECatch { body; exn; handler } ->
let body' = typecheck_expr_top_down ~leave_unresolved ctx env tau body in let body' = typecheck_expr_top_down ctx env tau body in
let handler' = let handler' = typecheck_expr_top_down ctx env tau handler in
typecheck_expr_top_down ~leave_unresolved ctx env tau handler
in
Expr.ecatch body' exn handler' context_mark Expr.ecatch body' exn handler' context_mark
| A.EVar v -> | A.EVar v ->
let tau' = let tau' =
@ -732,9 +755,7 @@ and typecheck_expr_top_down :
| A.ETuple es -> | A.ETuple es ->
let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in
let mark = mark_with_tau_and_unify (unionfind (TTuple tys)) in let mark = mark_with_tau_and_unify (unionfind (TTuple tys)) in
let es' = let es' = List.map2 (typecheck_expr_top_down ctx env) tys es in
List.map2 (typecheck_expr_top_down ~leave_unresolved ctx env) tys es
in
Expr.etuple es' mark Expr.etuple es' mark
| A.ETupleAccess { e = e1; index; size } -> | A.ETupleAccess { e = e1; index; size } ->
if index >= size then if index >= size then
@ -745,11 +766,7 @@ and typecheck_expr_top_down :
(List.init size (fun n -> (List.init size (fun n ->
if n = index then tau else unionfind ~pos:e1 (TAny (Any.fresh ())))) if n = index then tau else unionfind ~pos:e1 (TAny (Any.fresh ()))))
in in
let e1' = let e1' = typecheck_expr_top_down ctx env (unionfind ~pos:e1 tuple_ty) e1 in
typecheck_expr_top_down ~leave_unresolved ctx env
(unionfind ~pos:e1 tuple_ty)
e1
in
Expr.etupleaccess ~e:e1' ~index ~size context_mark Expr.etupleaccess ~e:e1' ~index ~size context_mark
| A.EAbs { binder; tys = t_args } -> | A.EAbs { binder; tys = t_args } ->
if Bindlib.mbinder_arity binder <> List.length t_args then if Bindlib.mbinder_arity binder <> List.length t_args then
@ -769,11 +786,9 @@ and typecheck_expr_top_down :
(fun env x tau_arg -> Env.add x tau_arg env) (fun env x tau_arg -> Env.add x tau_arg env)
env (Array.to_list xs) tau_args env (Array.to_list xs) tau_args
in in
let body' = let body' = typecheck_expr_top_down ctx env t_ret body in
typecheck_expr_top_down ~leave_unresolved ctx env t_ret body
in
let binder' = Bindlib.bind_mvar xs' (Expr.Box.lift body') in let binder' = Bindlib.bind_mvar xs' (Expr.Box.lift body') in
Expr.eabs binder' (List.map (typ_to_ast ~leave_unresolved) tau_args) mark Expr.eabs binder' (List.map (typ_to_ast ~flags) tau_args) mark
| A.EApp { f = e1; args; tys } -> | A.EApp { f = e1; args; tys } ->
(* Here we type the arguments first (in order), to ensure we know the types (* Here we type the arguments first (in order), to ensure we know the types
of the arguments if [f] is [EAbs] before disambiguation. This is also the of the arguments if [f] is [EAbs] before disambiguation. This is also the
@ -783,9 +798,7 @@ and typecheck_expr_top_down :
| [] -> List.map (fun _ -> unionfind (TAny (Any.fresh ()))) args | [] -> List.map (fun _ -> unionfind (TAny (Any.fresh ()))) args
| tys -> List.map ast_to_typ tys | tys -> List.map ast_to_typ tys
in in
let args' = let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
List.map2 (typecheck_expr_top_down ~leave_unresolved ctx env) t_args args
in
let t_args = let t_args =
match t_args, tys with match t_args, tys with
| [t], [] -> ( | [t], [] -> (
@ -805,9 +818,9 @@ and typecheck_expr_top_down :
t_args t_args
in in
let t_func = unionfind ~pos:e1 (TArrow (t_args, tau)) in let t_func = unionfind ~pos:e1 (TArrow (t_args, tau)) in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_func e1 in let e1' = typecheck_expr_top_down ctx env t_func e1 in
Expr.eapp ~f:e1' ~args:args' Expr.eapp ~f:e1' ~args:args'
~tys:(List.map (typ_to_ast ~leave_unresolved) t_args) ~tys:(List.map (typ_to_ast ~flags) t_args)
context_mark context_mark
| A.EAppOp { op; tys; args } -> | A.EAppOp { op; tys; args } ->
let t_args = List.map ast_to_typ tys in let t_args = List.map ast_to_typ tys in
@ -818,87 +831,73 @@ and typecheck_expr_top_down :
(* Type the operator first, then right-to-left: polymorphic operators (* Type the operator first, then right-to-left: polymorphic operators
are required to allow the resolution of all type variables this are required to allow the resolution of all type variables this
way *) way *)
unify ctx e (polymorphic_op_type (Mark.add pos_e op)) t_func; if not env.flags.assume_op_types then
unify ctx e (polymorphic_op_type (Mark.add pos_e op)) t_func
else
unify ctx e
(polymorphic_op_return_type ctx e (Mark.add pos_e op) t_args)
tau;
List.rev_map2 List.rev_map2
(typecheck_expr_top_down ~leave_unresolved ctx env) (typecheck_expr_top_down ctx env)
(List.rev t_args) (List.rev args)) (List.rev t_args) (List.rev args))
~overloaded:(fun op -> ~overloaded:(fun op ->
(* Typing the arguments first is required to resolve the operator *) (* Typing the arguments first is required to resolve the operator *)
let args' = let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
List.map2 unify ctx e tau (resolve_overload_ret_type ~flags ctx e op t_args);
(typecheck_expr_top_down ~leave_unresolved ctx env)
t_args args
in
unify ctx e tau
(resolve_overload_ret_type ~leave_unresolved ctx e op t_args);
args') args')
~monomorphic:(fun op -> ~monomorphic:(fun op ->
(* Here it doesn't matter but may affect the error messages *) (* Here it doesn't matter but may affect the error messages *)
unify ctx e unify ctx e
(ast_to_typ (Operator.monomorphic_type (Mark.add pos_e op))) (ast_to_typ (Operator.monomorphic_type (Mark.add pos_e op)))
t_func; t_func;
List.map2 List.map2 (typecheck_expr_top_down ctx env) t_args args)
(typecheck_expr_top_down ~leave_unresolved ctx env)
t_args args)
~resolved:(fun op -> ~resolved:(fun op ->
(* This case should not fail *) (* This case should not fail *)
unify ctx e unify ctx e
(ast_to_typ (Operator.resolved_type (Mark.add pos_e op))) (ast_to_typ (Operator.resolved_type (Mark.add pos_e op)))
t_func; t_func;
List.map2 List.map2 (typecheck_expr_top_down ctx env) t_args args)
(typecheck_expr_top_down ~leave_unresolved ctx env)
t_args args)
in in
(* All operator applications are monomorphised at this point *) (* All operator applications are monomorphised at this point *)
let tys = List.map (typ_to_ast ~leave_unresolved) t_args in let tys = List.map (typ_to_ast ~flags) t_args in
Expr.eappop ~op ~args ~tys context_mark Expr.eappop ~op ~args ~tys context_mark
| A.EDefault { excepts; just; cons } -> | A.EDefault { excepts; just; cons } ->
let cons' = typecheck_expr_top_down ~leave_unresolved ctx env tau cons in let cons' = typecheck_expr_top_down ctx env tau cons in
let just' = let just' =
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env (unionfind ~pos:just (TLit TBool)) just
(unionfind ~pos:just (TLit TBool))
just
in
let excepts' =
List.map (typecheck_expr_top_down ~leave_unresolved ctx env tau) excepts
in in
let excepts' = List.map (typecheck_expr_top_down ctx env tau) excepts in
Expr.edefault ~excepts:excepts' ~just:just' ~cons:cons' context_mark Expr.edefault ~excepts:excepts' ~just:just' ~cons:cons' context_mark
| A.EPureDefault e1 -> | A.EPureDefault e1 ->
let inner_ty = unionfind ~pos:e1 (TAny (Any.fresh ())) in let inner_ty = unionfind ~pos:e1 (TAny (Any.fresh ())) in
let mark = let mark =
mark_with_tau_and_unify (unionfind ~pos:e1 (TDefault inner_ty)) mark_with_tau_and_unify (unionfind ~pos:e1 (TDefault inner_ty))
in in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env inner_ty e1 in let e1' = typecheck_expr_top_down ctx env inner_ty e1 in
Expr.epuredefault e1' mark Expr.epuredefault e1' mark
| A.EIfThenElse { cond; etrue = et; efalse = ef } -> | A.EIfThenElse { cond; etrue = et; efalse = ef } ->
let et' = typecheck_expr_top_down ~leave_unresolved ctx env tau et in let et' = typecheck_expr_top_down ctx env tau et in
let ef' = typecheck_expr_top_down ~leave_unresolved ctx env tau ef in let ef' = typecheck_expr_top_down ctx env tau ef in
let cond' = let cond' =
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env (unionfind ~pos:cond (TLit TBool)) cond
(unionfind ~pos:cond (TLit TBool))
cond
in in
Expr.eifthenelse cond' et' ef' context_mark Expr.eifthenelse cond' et' ef' context_mark
| A.EAssert e1 -> | A.EAssert e1 ->
let mark = mark_with_tau_and_unify (unionfind (TLit TUnit)) in let mark = mark_with_tau_and_unify (unionfind (TLit TUnit)) in
let e1' = let e1' =
typecheck_expr_top_down ~leave_unresolved ctx env typecheck_expr_top_down ctx env (unionfind ~pos:e1 (TLit TBool)) e1
(unionfind ~pos:e1 (TLit TBool))
e1
in in
Expr.eassert e1' mark Expr.eassert e1' mark
| A.EEmptyError -> | A.EEmptyError ->
Expr.eemptyerror (ty_mark (TDefault (unionfind (TAny (Any.fresh ()))))) Expr.eemptyerror (ty_mark (TDefault (unionfind (TAny (Any.fresh ())))))
| A.EErrorOnEmpty e1 -> | A.EErrorOnEmpty e1 ->
let tau' = unionfind (TDefault tau) in let tau' = unionfind (TDefault tau) in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env tau' e1 in let e1' = typecheck_expr_top_down ctx env tau' e1 in
Expr.eerroronempty e1' context_mark Expr.eerroronempty e1' context_mark
| A.EArray es -> | A.EArray es ->
let cell_type = unionfind (TAny (Any.fresh ())) in let cell_type = unionfind (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify (unionfind (TArray cell_type)) in let mark = mark_with_tau_and_unify (unionfind (TArray cell_type)) in
let es' = let es' = List.map (typecheck_expr_top_down ctx env cell_type) es in
List.map (typecheck_expr_top_down ~leave_unresolved ctx env cell_type) es
in
Expr.earray es' mark Expr.earray es' mark
| A.ECustom { obj; targs; tret } -> | A.ECustom { obj; targs; tret } ->
let mark = let mark =
@ -920,42 +919,36 @@ let wrap_expr ctx f e =
(** {1 API} *) (** {1 API} *)
let get_ty_mark ~leave_unresolved (A.Custom { A.custom = uf; pos }) = let get_ty_mark ~flags (A.Custom { A.custom = uf; pos }) =
A.Typed { ty = typ_to_ast ~leave_unresolved uf; pos } A.Typed { ty = typ_to_ast ~flags uf; pos }
let expr_raw let expr_raw
(type a) (type a)
~(leave_unresolved : resolving_strategy)
(ctx : A.decl_ctx) (ctx : A.decl_ctx)
?(env = Env.empty ctx) ?(env = Env.empty ctx)
?(typ : A.typ option) ?(typ : A.typ option)
(e : (a, 'm) A.gexpr) : (a, unionfind_typ A.custom) A.gexpr = (e : (a, 'm) A.gexpr) : (a, unionfind_typ A.custom) A.gexpr =
let fty = let fty =
match typ with match typ with
| None -> typecheck_expr_bottom_up ~leave_unresolved ctx env | None -> typecheck_expr_bottom_up ctx env
| Some typ -> | Some typ -> typecheck_expr_top_down ctx env (ast_to_typ typ)
typecheck_expr_top_down ~leave_unresolved ctx env (ast_to_typ typ)
in in
wrap_expr ctx fty e wrap_expr ctx fty e
let check_expr ~leave_unresolved ctx ?env ?typ e = let check_expr ctx ?env ?typ e =
Expr.map_marks Expr.map_marks
~f:(fun (Custom { pos; _ }) -> A.Untyped { pos }) ~f:(fun (Custom { pos; _ }) -> A.Untyped { pos })
(expr_raw ctx ~leave_unresolved ?env ?typ e) (expr_raw ctx ?env ?typ e)
(* Infer the type of an expression *) (* Infer the type of an expression *)
let expr ~leave_unresolved ctx ?env ?typ e = let expr ctx ?(env = Env.empty ctx) ?typ e =
Expr.map_marks Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) (expr_raw ctx ~env ?typ e)
~f:(get_ty_mark ~leave_unresolved)
(expr_raw ~leave_unresolved ctx ?env ?typ e)
let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr = let rec scope_body_expr ctx env ty_out body_expr =
match body_expr with match body_expr with
| A.Result e -> | A.Result e ->
let e' = let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in
wrap_expr ctx (typecheck_expr_top_down ~leave_unresolved ctx env ty_out) e let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
in
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in
Bindlib.box_apply (fun e -> A.Result e) (Expr.Box.lift e') Bindlib.box_apply (fun e -> A.Result e) (Expr.Box.lift e')
| A.ScopeLet | A.ScopeLet
{ {
@ -966,9 +959,7 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
scope_let_pos; scope_let_pos;
} -> } ->
let ty_e = ast_to_typ scope_let_typ in let ty_e = ast_to_typ scope_let_typ in
let e = let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in
wrap_expr ctx (typecheck_expr_bottom_up ~leave_unresolved ctx env) e0
in
wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e; wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e;
(* We could use [typecheck_expr_top_down] rather than this manual (* We could use [typecheck_expr_top_down] rather than this manual
unification, but we get better messages with this order of the [unify] unification, but we get better messages with this order of the [unify]
@ -976,7 +967,7 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
inferred. *) inferred. *)
let var, next = Bindlib.unbind scope_let_next in let var, next = Bindlib.unbind scope_let_next in
let env = Env.add var ty_e env in let env = Env.add var ty_e env in
let next = scope_body_expr ~leave_unresolved ctx env ty_out next in let next = scope_body_expr ctx env ty_out next in
let scope_let_next = Bindlib.bind_var (Var.translate var) next in let scope_let_next = Bindlib.bind_var (Var.translate var) next in
Bindlib.box_apply2 Bindlib.box_apply2
(fun scope_let_expr scope_let_next -> (fun scope_let_expr scope_let_next ->
@ -985,16 +976,16 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
scope_let_kind; scope_let_kind;
scope_let_typ = scope_let_typ =
(match Mark.remove scope_let_typ with (match Mark.remove scope_let_typ with
| TAny -> typ_to_ast ~leave_unresolved (ty e) | TAny -> typ_to_ast ~flags:env.flags (ty e)
| _ -> scope_let_typ); | _ -> scope_let_typ);
scope_let_expr; scope_let_expr;
scope_let_next; scope_let_next;
scope_let_pos; scope_let_pos;
}) })
(Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e)) (Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e))
scope_let_next scope_let_next
let scope_body ~leave_unresolved ctx env body = let scope_body ctx env body =
let get_pos struct_name = Mark.get (A.StructName.get_info struct_name) in let get_pos struct_name = Mark.get (A.StructName.get_info struct_name) in
let struct_ty struct_name = let struct_ty struct_name =
UnionFind.make (Mark.add (get_pos struct_name) (TStruct struct_name)) UnionFind.make (Mark.add (get_pos struct_name) (TStruct struct_name))
@ -1003,7 +994,7 @@ let scope_body ~leave_unresolved ctx env body =
let ty_out = struct_ty body.A.scope_body_output_struct in let ty_out = struct_ty body.A.scope_body_output_struct in
let var, e = Bindlib.unbind body.A.scope_body_expr in let var, e = Bindlib.unbind body.A.scope_body_expr in
let env = Env.add var ty_in env in let env = Env.add var ty_in env in
let e' = scope_body_expr ~leave_unresolved ctx env ty_out e in let e' = scope_body_expr ctx env ty_out e in
( Bindlib.box_apply ( Bindlib.box_apply
(fun scope_body_expr -> { body with scope_body_expr }) (fun scope_body_expr -> { body with scope_body_expr })
(Bindlib.bind_var (Var.translate var) e'), (Bindlib.bind_var (Var.translate var) e'),
@ -1012,35 +1003,33 @@ let scope_body ~leave_unresolved ctx env body =
(get_pos body.A.scope_body_output_struct) (get_pos body.A.scope_body_output_struct)
(TArrow ([ty_in], ty_out))) ) (TArrow ([ty_in], ty_out))) )
let rec scopes ~leave_unresolved ctx env = function let rec scopes ctx env = function
| A.Nil -> Bindlib.box A.Nil, env | A.Nil -> Bindlib.box A.Nil, env
| A.Cons (item, next_bind) -> | A.Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in let var, next = Bindlib.unbind next_bind in
let env, def = let env, def =
match item with match item with
| A.ScopeDef (name, body) -> | A.ScopeDef (name, body) ->
let body_e, ty_scope = scope_body ~leave_unresolved ctx env body in let body_e, ty_scope = scope_body ctx env body in
( Env.add var ty_scope env, ( Env.add var ty_scope env,
Bindlib.box_apply (fun body -> A.ScopeDef (name, body)) body_e ) Bindlib.box_apply (fun body -> A.ScopeDef (name, body)) body_e )
| A.Topdef (name, typ, e) -> | A.Topdef (name, typ, e) ->
let e' = expr_raw ~leave_unresolved ctx ~env ~typ e in let e' = expr_raw ctx ~env ~typ e in
let (A.Custom { custom = uf; _ }) = Mark.get e' in let (A.Custom { custom = uf; _ }) = Mark.get e' in
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
( Env.add var uf env, ( Env.add var uf env,
Bindlib.box_apply Bindlib.box_apply
(fun e -> A.Topdef (name, Expr.ty e', e)) (fun e -> A.Topdef (name, Expr.ty e', e))
(Expr.Box.lift e') ) (Expr.Box.lift e') )
in in
let next', env = scopes ~leave_unresolved ctx env next in let next', env = scopes ctx env next in
let next_bind' = Bindlib.bind_var (Var.translate var) next' in let next_bind' = Bindlib.bind_var (Var.translate var) next' in
( Bindlib.box_apply2 (fun item next -> A.Cons (item, next)) def next_bind', ( Bindlib.box_apply2 (fun item next -> A.Cons (item, next)) def next_bind',
env ) env )
let program ~leave_unresolved prg = let program ?fail_on_any ?assume_op_types prg =
let code_items, new_env = let env = Env.empty ?fail_on_any ?assume_op_types prg.A.decl_ctx in
scopes ~leave_unresolved prg.A.decl_ctx (Env.empty prg.A.decl_ctx) let code_items, new_env = scopes prg.A.decl_ctx env prg.A.code_items in
prg.A.code_items
in
{ {
A.lang = prg.lang; A.lang = prg.lang;
A.module_name = prg.A.module_name; A.module_name = prg.A.module_name;
@ -1055,7 +1044,7 @@ let program ~leave_unresolved prg =
(fun f_name (t : A.typ) -> (fun f_name (t : A.typ) ->
match Mark.remove t with match Mark.remove t with
| TAny -> | TAny ->
typ_to_ast ~leave_unresolved typ_to_ast ~flags:env.flags
(A.StructField.Map.find f_name (A.StructField.Map.find f_name
(A.StructName.Map.find s_name new_env.structs)) (A.StructName.Map.find s_name new_env.structs))
| _ -> t) | _ -> t)
@ -1068,7 +1057,7 @@ let program ~leave_unresolved prg =
(fun cons_name (t : A.typ) -> (fun cons_name (t : A.typ) ->
match Mark.remove t with match Mark.remove t with
| TAny -> | TAny ->
typ_to_ast ~leave_unresolved typ_to_ast ~flags:env.flags
(A.EnumConstructor.Map.find cons_name (A.EnumConstructor.Map.find cons_name
(A.EnumName.Map.find e_name new_env.enums)) (A.EnumName.Map.find e_name new_env.enums))
| _ -> t) | _ -> t)

View File

@ -22,7 +22,17 @@ open Definitions
module Env : sig module Env : sig
type 'e t type 'e t
val empty : decl_ctx -> 'e t val empty : ?fail_on_any:bool -> ?assume_op_types:bool -> decl_ctx -> 'e t
(** The [~fail_on_any] labeled parameter controls the behavior of the typer in
the case where polymorphic expressions are still found after typing: if
[false], it allows them (giving them [TAny] and losing typing
information); if set to [true] (the default), it aborts.
The [~assume_op_types] flag (default false) ignores the expected built-in
types of polymorphic operators, and will assume correct the type
information included in [EAppOp] nodes. This is useful after
monomorphisation, which changes the expected types for these operators. *)
val add_var : 'e Var.t -> typ -> 'e t -> 'e t val add_var : 'e Var.t -> typ -> 'e t -> 'e t
val add_toplevel_var : TopdefName.t -> typ -> 'e t -> 'e t val add_toplevel_var : TopdefName.t -> typ -> 'e t -> 'e t
val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t
@ -40,15 +50,7 @@ module Env : sig
(** For debug purposes *) (** For debug purposes *)
end end
(** In the following functions, the [~leave_unresolved] labeled parameter
controls the behavior of the typer in the case where polymorphic expressions
are still found after typing: if set to [LeaveAny], it allows them (giving
them [TAny] and losing typing information); if set to [ErrorOnAny], it
aborts. *)
type resolving_strategy = LeaveAny | ErrorOnAny
val expr : val expr :
leave_unresolved:resolving_strategy ->
decl_ctx -> decl_ctx ->
?env:'e Env.t -> ?env:'e Env.t ->
?typ:typ -> ?typ:typ ->
@ -75,11 +77,10 @@ val expr :
application, taking de-tuplification into account. application, taking de-tuplification into account.
- [TAny] appearing within nodes are refined to more precise types, e.g. on - [TAny] appearing within nodes are refined to more precise types, e.g. on
`EAbs` nodes (but be careful with this, it may only work for specific `EAbs` nodes (but be careful with this, it may only work for specific
structures of generated code ; [~leave_unresolved:false] checks that it structures of generated code ; having [~fail_on_any:true] set in the
didn't cause problems) *) environment (this is the default) checks that it didn't cause problems) *)
val check_expr : val check_expr :
leave_unresolved:resolving_strategy ->
decl_ctx -> decl_ctx ->
?env:'e Env.t -> ?env:'e Env.t ->
?typ:typ -> ?typ:typ ->
@ -91,7 +92,8 @@ val check_expr :
information, e.g. any [TAny] appearing in the AST is replaced) *) information, e.g. any [TAny] appearing in the AST is replaced) *)
val program : val program :
leave_unresolved:resolving_strategy -> ?fail_on_any:bool ->
?assume_op_types:bool ->
('a, 'm) gexpr program -> ('a, 'm) gexpr program ->
('a, typed) gexpr program ('a, typed) gexpr program
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the (** Typing on whole programs (as defined in Shared_ast.program, i.e. for the