mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Typing: add a "assume operator types" mode
This allows for retyping after monomorphisation: a new function just extracts the return type of the operator, without checking the operand types. Also to avoid multiplying function arguments around the typer, the flags have been gathered in a record that is included in the typing environment; it's ok to give them default values as long as these are the strictest.
This commit is contained in:
parent
ae89c870c1
commit
a56d95d790
@ -23,7 +23,7 @@ let expr ctx env e =
|
||||
[Some] *)
|
||||
(* Intermediate unboxings are fine since the [check_expr] will rebox in
|
||||
depth *)
|
||||
Typing.check_expr ~leave_unresolved:ErrorOnAny ctx ~env (Expr.unbox e)
|
||||
Typing.check_expr ctx ~env (Expr.unbox e)
|
||||
|
||||
let rule ctx env rule =
|
||||
let env =
|
||||
|
@ -192,7 +192,7 @@ module Passes = struct
|
||||
match typed with
|
||||
| Typed _ -> (
|
||||
Message.emit_debug "Typechecking again...";
|
||||
try Typing.program ~leave_unresolved:ErrorOnAny prg
|
||||
try Typing.program prg
|
||||
with Message.CompilerError error_content ->
|
||||
let bt = Printexc.get_raw_backtrace () in
|
||||
Printexc.raise_with_backtrace
|
||||
@ -257,7 +257,7 @@ module Passes = struct
|
||||
let prg =
|
||||
if not closure_conversion then (
|
||||
Message.emit_debug "Retyping lambda calculus...";
|
||||
Typing.program ~leave_unresolved:LeaveAny prg)
|
||||
Typing.program ~fail_on_any:false prg)
|
||||
else (
|
||||
Message.emit_debug "Performing closure conversion...";
|
||||
let prg = Lcalc.Closure_conversion.closure_conversion prg in
|
||||
@ -268,16 +268,15 @@ module Passes = struct
|
||||
else prg
|
||||
in
|
||||
Message.emit_debug "Retyping lambda calculus...";
|
||||
Typing.program ~leave_unresolved:LeaveAny prg)
|
||||
Typing.program ~fail_on_any:false prg)
|
||||
in
|
||||
let prg, type_ordering =
|
||||
if monomorphize_types then (
|
||||
Message.emit_debug "Monomorphizing types...";
|
||||
Lcalc.Monomorphize.program prg
|
||||
(* (* FIXME: typing no longer works after monomorphisation, it would
|
||||
* need special operator handling for arrays and options *)
|
||||
* Message.emit_debug "Retyping lambda calculus...";
|
||||
* let prg = Typing.program ~leave_unresolved:LeaveAny prg in *))
|
||||
let prg, type_ordering = Lcalc.Monomorphize.program prg in
|
||||
Message.emit_debug "Retyping lambda calculus...";
|
||||
let prg = Typing.program ~fail_on_any:false ~assume_op_types:true prg in
|
||||
prg, type_ordering)
|
||||
else prg, type_ordering
|
||||
in
|
||||
prg, type_ordering
|
||||
@ -556,10 +555,7 @@ module Commands = struct
|
||||
|
||||
(* Additionally, we might want to check the invariants. *)
|
||||
if check_invariants then (
|
||||
let prg =
|
||||
Shared_ast.Typing.program ~leave_unresolved:ErrorOnAny
|
||||
(Program.untype prg)
|
||||
in
|
||||
let prg = Shared_ast.Typing.program prg in
|
||||
Message.emit_debug "Checking invariants...";
|
||||
if Dcalc.Invariants.check_all_invariants prg then
|
||||
Message.emit_result "All invariant checks passed"
|
||||
|
@ -67,15 +67,11 @@ type 'm program = {
|
||||
|
||||
let type_rule decl_ctx env = function
|
||||
| Definition (loc, typ, io, expr) ->
|
||||
let expr' =
|
||||
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
|
||||
in
|
||||
let expr' = Typing.expr decl_ctx ~env ~typ expr in
|
||||
Definition (loc, typ, io, Expr.unbox expr')
|
||||
| Assertion expr ->
|
||||
let typ = Mark.add (Expr.pos expr) (TLit TBool) in
|
||||
let expr' =
|
||||
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
|
||||
in
|
||||
let expr' = Typing.expr decl_ctx ~env ~typ expr in
|
||||
Assertion (Expr.unbox expr')
|
||||
| Call (sc_name, ssc_name, m) ->
|
||||
let pos = Expr.mark_pos m in
|
||||
@ -118,10 +114,7 @@ let type_program (type m) (prg : m program) : typed program =
|
||||
let program_topdefs =
|
||||
TopdefName.Map.map
|
||||
(fun (expr, typ) ->
|
||||
( Expr.unbox
|
||||
(Typing.expr prg.program_ctx ~leave_unresolved:ErrorOnAny ~env ~typ
|
||||
expr),
|
||||
typ ))
|
||||
Expr.unbox (Typing.expr prg.program_ctx ~env ~typ expr), typ)
|
||||
prg.program_topdefs
|
||||
in
|
||||
let program_scopes =
|
||||
|
@ -20,7 +20,7 @@
|
||||
open Catala_utils
|
||||
module A = Definitions
|
||||
|
||||
type resolving_strategy = LeaveAny | ErrorOnAny
|
||||
type flags = { fail_on_any : bool; assume_op_types : bool }
|
||||
|
||||
module Any =
|
||||
Uid.Make
|
||||
@ -54,9 +54,8 @@ and naked_typ =
|
||||
| TAny of Any.t
|
||||
| TClosureEnv
|
||||
|
||||
let rec typ_to_ast ~(leave_unresolved : resolving_strategy) (ty : unionfind_typ)
|
||||
: A.typ =
|
||||
let typ_to_ast = typ_to_ast ~leave_unresolved in
|
||||
let rec typ_to_ast ~(flags : flags) (ty : unionfind_typ) : A.typ =
|
||||
let typ_to_ast = typ_to_ast ~flags in
|
||||
let ty, pos = UnionFind.get (UnionFind.find ty) in
|
||||
match ty with
|
||||
| TLit l -> A.TLit l, pos
|
||||
@ -67,15 +66,14 @@ let rec typ_to_ast ~(leave_unresolved : resolving_strategy) (ty : unionfind_typ)
|
||||
| TArrow (t1, t2) -> A.TArrow (List.map typ_to_ast t1, typ_to_ast t2), pos
|
||||
| TArray t1 -> A.TArray (typ_to_ast t1), pos
|
||||
| TDefault t1 -> A.TDefault (typ_to_ast t1), pos
|
||||
| TAny _ -> (
|
||||
match leave_unresolved with
|
||||
| LeaveAny -> A.TAny, pos
|
||||
| ErrorOnAny ->
|
||||
| TAny _ ->
|
||||
if flags.fail_on_any then
|
||||
(* No polymorphism in Catala: type inference should return full types
|
||||
without wildcards, and this function is used to recover the types after
|
||||
typing. *)
|
||||
Message.raise_spanned_error pos
|
||||
"Internal error: typing at this point could not be resolved")
|
||||
"Internal error: typing at this point could not be resolved"
|
||||
else A.TAny, pos
|
||||
| TClosureEnv -> TClosureEnv, pos
|
||||
|
||||
let rec ast_to_typ (ty : A.typ) : unionfind_typ =
|
||||
@ -321,8 +319,39 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
|
||||
in
|
||||
Lazy.force ty
|
||||
|
||||
(* Just returns the return type of the operator, assuming the operand types are
|
||||
known. Less trict, but useful on monomorphised code where the operators no
|
||||
longer have their standard types *)
|
||||
let polymorphic_op_return_type
|
||||
ctx
|
||||
e
|
||||
(op : Operator.polymorphic A.operator Mark.pos)
|
||||
(targs : unionfind_typ list) : unionfind_typ =
|
||||
let open Operator in
|
||||
let pos = Mark.get op in
|
||||
let uf t = UnionFind.make (t, pos) in
|
||||
let any _ = uf (TAny (Any.fresh ())) in
|
||||
let return_type tf arity =
|
||||
let tret = any () in
|
||||
unify ctx e tf (UnionFind.make (TArrow (List.init arity any, tret), pos));
|
||||
tret
|
||||
in
|
||||
match Mark.remove op, targs with
|
||||
| Fold, [_; tau; _] -> tau
|
||||
| Eq, _ -> uf (TLit TBool)
|
||||
| Map, [tf; _] -> uf (TArray (return_type tf 1))
|
||||
| Map2, [tf; _; _] -> uf (TArray (return_type tf 2))
|
||||
| (Filter | Reduce | Concat), [_; tau] -> tau
|
||||
| Log (PosRecordIfTrueBool, _), _ -> uf (TLit TBool)
|
||||
| Log _, [tau] -> tau
|
||||
| Length, _ -> uf (TLit TInt)
|
||||
| (HandleDefault | HandleDefaultOpt), [_; _; tf] -> return_type tf 1
|
||||
| ToClosureEnv, _ -> uf TClosureEnv
|
||||
| FromClosureEnv, _ -> any ()
|
||||
| _ -> Message.raise_spanned_error pos "Mismatched operator arguments"
|
||||
|
||||
let resolve_overload_ret_type
|
||||
~leave_unresolved
|
||||
~flags
|
||||
(ctx : A.decl_ctx)
|
||||
e
|
||||
(op : Operator.overloaded A.operator)
|
||||
@ -330,7 +359,7 @@ let resolve_overload_ret_type
|
||||
let op_ty =
|
||||
Operator.overload_type ctx
|
||||
(Mark.add (Expr.pos e) op)
|
||||
(List.map (typ_to_ast ~leave_unresolved) tys)
|
||||
(List.map (typ_to_ast ~flags) tys)
|
||||
in
|
||||
ast_to_typ (Type.arrow_return op_ty)
|
||||
|
||||
@ -338,6 +367,7 @@ let resolve_overload_ret_type
|
||||
|
||||
module Env = struct
|
||||
type 'e t = {
|
||||
flags : flags;
|
||||
structs : unionfind_typ A.StructField.Map.t A.StructName.Map.t;
|
||||
enums : unionfind_typ A.EnumConstructor.Map.t A.EnumName.Map.t;
|
||||
vars : ('e, unionfind_typ) Var.Map.t;
|
||||
@ -347,10 +377,14 @@ module Env = struct
|
||||
toplevel_vars : A.typ A.TopdefName.Map.t;
|
||||
}
|
||||
|
||||
let empty (decl_ctx : A.decl_ctx) =
|
||||
let empty
|
||||
?(fail_on_any = true)
|
||||
?(assume_op_types = false)
|
||||
(decl_ctx : A.decl_ctx) =
|
||||
(* We fill the environment initially with the structs and enums
|
||||
declarations *)
|
||||
{
|
||||
flags = { fail_on_any; assume_op_types };
|
||||
structs =
|
||||
A.StructName.Map.map
|
||||
(fun ty -> A.StructField.Map.map ast_to_typ ty)
|
||||
@ -423,29 +457,28 @@ let ty : (_, unionfind_typ A.custom) A.marked -> unionfind_typ =
|
||||
(** Infers the most permissive type from an expression *)
|
||||
let rec typecheck_expr_bottom_up :
|
||||
type a m.
|
||||
leave_unresolved:resolving_strategy ->
|
||||
A.decl_ctx ->
|
||||
(a, m) A.gexpr Env.t ->
|
||||
(a, m) A.gexpr ->
|
||||
(a, unionfind_typ A.custom) A.boxed_gexpr =
|
||||
fun ~leave_unresolved ctx env e ->
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
fun ctx env e ->
|
||||
typecheck_expr_top_down ctx env
|
||||
(UnionFind.make (add_pos e (TAny (Any.fresh ()))))
|
||||
e
|
||||
|
||||
(** Checks whether the expression can be typed with the provided type *)
|
||||
and typecheck_expr_top_down :
|
||||
type a m.
|
||||
leave_unresolved:resolving_strategy ->
|
||||
A.decl_ctx ->
|
||||
(a, m) A.gexpr Env.t ->
|
||||
unionfind_typ ->
|
||||
(a, m) A.gexpr ->
|
||||
(a, unionfind_typ A.custom) A.boxed_gexpr =
|
||||
fun ~leave_unresolved ctx env tau e ->
|
||||
fun ctx env tau e ->
|
||||
(* Message.emit_debug "Propagating type %a for naked_expr :@.@[<hov 2>%a@]"
|
||||
(format_typ ctx) tau Expr.format e; *)
|
||||
let pos_e = Expr.pos e in
|
||||
let flags = env.flags in
|
||||
let () =
|
||||
(* If there already is a type annotation on the given expr, ensure it
|
||||
matches *)
|
||||
@ -519,7 +552,7 @@ and typecheck_expr_top_down :
|
||||
A.StructField.Map.mapi
|
||||
(fun f_name f_e ->
|
||||
let f_ty = A.StructField.Map.find f_name str in
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env f_ty f_e)
|
||||
typecheck_expr_top_down ctx env f_ty f_e)
|
||||
fields
|
||||
in
|
||||
Expr.estruct ~name ~fields mark
|
||||
@ -530,8 +563,7 @@ and typecheck_expr_top_down :
|
||||
| None -> TAny (Any.fresh ())
|
||||
in
|
||||
let e_struct' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind t_struct)
|
||||
e_struct
|
||||
typecheck_expr_top_down ctx env (unionfind t_struct) e_struct
|
||||
in
|
||||
let name =
|
||||
match UnionFind.get (ty e_struct') with
|
||||
@ -598,8 +630,7 @@ and typecheck_expr_top_down :
|
||||
in
|
||||
let mark = mark_with_tau_and_unify fld_ty in
|
||||
let e_struct' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind (TStruct name)) e_struct
|
||||
typecheck_expr_top_down ctx env (unionfind (TStruct name)) e_struct
|
||||
in
|
||||
Expr.estructaccess ~e:e_struct' ~field ~name mark
|
||||
| A.EInj { name; cons; e = e_enum }
|
||||
@ -607,23 +638,20 @@ and typecheck_expr_top_down :
|
||||
if Definitions.EnumConstructor.equal cons Expr.some_constr then
|
||||
let cell_type = unionfind (TAny (Any.fresh ())) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
|
||||
let e_enum' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env cell_type e_enum
|
||||
in
|
||||
let e_enum' = typecheck_expr_top_down ctx env cell_type e_enum in
|
||||
Expr.einj ~name ~cons ~e:e_enum' mark
|
||||
else
|
||||
(* None constructor *)
|
||||
let cell_type = unionfind (TAny (Any.fresh ())) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
|
||||
let e_enum' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind (TLit TUnit)) e_enum
|
||||
typecheck_expr_top_down ctx env (unionfind (TLit TUnit)) e_enum
|
||||
in
|
||||
Expr.einj ~name ~cons ~e:e_enum' mark
|
||||
| A.EInj { name; cons; e = e_enum } ->
|
||||
let mark = mark_with_tau_and_unify (unionfind (TEnum name)) in
|
||||
let e_enum' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
typecheck_expr_top_down ctx env
|
||||
(A.EnumConstructor.Map.find cons (A.EnumName.Map.find name env.enums))
|
||||
e_enum
|
||||
in
|
||||
@ -640,14 +668,14 @@ and typecheck_expr_top_down :
|
||||
in
|
||||
let t_ret = unionfind ~pos:e (TAny (Any.fresh ())) in
|
||||
let mark = mark_with_tau_and_unify t_ret in
|
||||
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_arg e1 in
|
||||
let e1' = typecheck_expr_top_down ctx env t_arg e1 in
|
||||
let cases =
|
||||
A.EnumConstructor.Map.merge
|
||||
(fun _ e e_ty ->
|
||||
match e, e_ty with
|
||||
| Some e, Some e_ty ->
|
||||
Some
|
||||
(typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(typecheck_expr_top_down ctx env
|
||||
(unionfind ~pos:e (TArrow ([e_ty], t_ret)))
|
||||
e)
|
||||
| _ -> assert false)
|
||||
@ -658,10 +686,7 @@ and typecheck_expr_top_down :
|
||||
let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in
|
||||
let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in
|
||||
let mark = mark_with_tau_and_unify t_ret in
|
||||
let e1' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TEnum name))
|
||||
e1
|
||||
in
|
||||
let e1' = typecheck_expr_top_down ctx env (unionfind (TEnum name)) e1 in
|
||||
let cases =
|
||||
A.EnumConstructor.Map.mapi
|
||||
(fun c_name e ->
|
||||
@ -670,7 +695,7 @@ and typecheck_expr_top_down :
|
||||
there is a change to allow for multiple arguments, it might be
|
||||
easier to use tuples directly. *)
|
||||
let e_ty = unionfind ~pos:e (TArrow ([ast_to_typ c_ty], t_ret)) in
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env e_ty e)
|
||||
typecheck_expr_top_down ctx env e_ty e)
|
||||
cases
|
||||
in
|
||||
Expr.ematch ~e:e1' ~name ~cases mark
|
||||
@ -683,17 +708,15 @@ and typecheck_expr_top_down :
|
||||
let args' =
|
||||
A.ScopeVar.Map.mapi
|
||||
(fun name ->
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
typecheck_expr_top_down ctx env
|
||||
(ast_to_typ (A.ScopeVar.Map.find name vars)))
|
||||
args
|
||||
in
|
||||
Expr.escopecall ~scope ~args:args' mark
|
||||
| A.ERaise ex -> Expr.eraise ex context_mark
|
||||
| A.ECatch { body; exn; handler } ->
|
||||
let body' = typecheck_expr_top_down ~leave_unresolved ctx env tau body in
|
||||
let handler' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env tau handler
|
||||
in
|
||||
let body' = typecheck_expr_top_down ctx env tau body in
|
||||
let handler' = typecheck_expr_top_down ctx env tau handler in
|
||||
Expr.ecatch body' exn handler' context_mark
|
||||
| A.EVar v ->
|
||||
let tau' =
|
||||
@ -732,9 +755,7 @@ and typecheck_expr_top_down :
|
||||
| A.ETuple es ->
|
||||
let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TTuple tys)) in
|
||||
let es' =
|
||||
List.map2 (typecheck_expr_top_down ~leave_unresolved ctx env) tys es
|
||||
in
|
||||
let es' = List.map2 (typecheck_expr_top_down ctx env) tys es in
|
||||
Expr.etuple es' mark
|
||||
| A.ETupleAccess { e = e1; index; size } ->
|
||||
if index >= size then
|
||||
@ -745,11 +766,7 @@ and typecheck_expr_top_down :
|
||||
(List.init size (fun n ->
|
||||
if n = index then tau else unionfind ~pos:e1 (TAny (Any.fresh ()))))
|
||||
in
|
||||
let e1' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind ~pos:e1 tuple_ty)
|
||||
e1
|
||||
in
|
||||
let e1' = typecheck_expr_top_down ctx env (unionfind ~pos:e1 tuple_ty) e1 in
|
||||
Expr.etupleaccess ~e:e1' ~index ~size context_mark
|
||||
| A.EAbs { binder; tys = t_args } ->
|
||||
if Bindlib.mbinder_arity binder <> List.length t_args then
|
||||
@ -769,11 +786,9 @@ and typecheck_expr_top_down :
|
||||
(fun env x tau_arg -> Env.add x tau_arg env)
|
||||
env (Array.to_list xs) tau_args
|
||||
in
|
||||
let body' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env t_ret body
|
||||
in
|
||||
let body' = typecheck_expr_top_down ctx env t_ret body in
|
||||
let binder' = Bindlib.bind_mvar xs' (Expr.Box.lift body') in
|
||||
Expr.eabs binder' (List.map (typ_to_ast ~leave_unresolved) tau_args) mark
|
||||
Expr.eabs binder' (List.map (typ_to_ast ~flags) tau_args) mark
|
||||
| A.EApp { f = e1; args; tys } ->
|
||||
(* Here we type the arguments first (in order), to ensure we know the types
|
||||
of the arguments if [f] is [EAbs] before disambiguation. This is also the
|
||||
@ -783,9 +798,7 @@ and typecheck_expr_top_down :
|
||||
| [] -> List.map (fun _ -> unionfind (TAny (Any.fresh ()))) args
|
||||
| tys -> List.map ast_to_typ tys
|
||||
in
|
||||
let args' =
|
||||
List.map2 (typecheck_expr_top_down ~leave_unresolved ctx env) t_args args
|
||||
in
|
||||
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
|
||||
let t_args =
|
||||
match t_args, tys with
|
||||
| [t], [] -> (
|
||||
@ -805,9 +818,9 @@ and typecheck_expr_top_down :
|
||||
t_args
|
||||
in
|
||||
let t_func = unionfind ~pos:e1 (TArrow (t_args, tau)) in
|
||||
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_func e1 in
|
||||
let e1' = typecheck_expr_top_down ctx env t_func e1 in
|
||||
Expr.eapp ~f:e1' ~args:args'
|
||||
~tys:(List.map (typ_to_ast ~leave_unresolved) t_args)
|
||||
~tys:(List.map (typ_to_ast ~flags) t_args)
|
||||
context_mark
|
||||
| A.EAppOp { op; tys; args } ->
|
||||
let t_args = List.map ast_to_typ tys in
|
||||
@ -818,87 +831,73 @@ and typecheck_expr_top_down :
|
||||
(* Type the operator first, then right-to-left: polymorphic operators
|
||||
are required to allow the resolution of all type variables this
|
||||
way *)
|
||||
unify ctx e (polymorphic_op_type (Mark.add pos_e op)) t_func;
|
||||
if not env.flags.assume_op_types then
|
||||
unify ctx e (polymorphic_op_type (Mark.add pos_e op)) t_func
|
||||
else
|
||||
unify ctx e
|
||||
(polymorphic_op_return_type ctx e (Mark.add pos_e op) t_args)
|
||||
tau;
|
||||
List.rev_map2
|
||||
(typecheck_expr_top_down ~leave_unresolved ctx env)
|
||||
(typecheck_expr_top_down ctx env)
|
||||
(List.rev t_args) (List.rev args))
|
||||
~overloaded:(fun op ->
|
||||
(* Typing the arguments first is required to resolve the operator *)
|
||||
let args' =
|
||||
List.map2
|
||||
(typecheck_expr_top_down ~leave_unresolved ctx env)
|
||||
t_args args
|
||||
in
|
||||
unify ctx e tau
|
||||
(resolve_overload_ret_type ~leave_unresolved ctx e op t_args);
|
||||
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
|
||||
unify ctx e tau (resolve_overload_ret_type ~flags ctx e op t_args);
|
||||
args')
|
||||
~monomorphic:(fun op ->
|
||||
(* Here it doesn't matter but may affect the error messages *)
|
||||
unify ctx e
|
||||
(ast_to_typ (Operator.monomorphic_type (Mark.add pos_e op)))
|
||||
t_func;
|
||||
List.map2
|
||||
(typecheck_expr_top_down ~leave_unresolved ctx env)
|
||||
t_args args)
|
||||
List.map2 (typecheck_expr_top_down ctx env) t_args args)
|
||||
~resolved:(fun op ->
|
||||
(* This case should not fail *)
|
||||
unify ctx e
|
||||
(ast_to_typ (Operator.resolved_type (Mark.add pos_e op)))
|
||||
t_func;
|
||||
List.map2
|
||||
(typecheck_expr_top_down ~leave_unresolved ctx env)
|
||||
t_args args)
|
||||
List.map2 (typecheck_expr_top_down ctx env) t_args args)
|
||||
in
|
||||
(* All operator applications are monomorphised at this point *)
|
||||
let tys = List.map (typ_to_ast ~leave_unresolved) t_args in
|
||||
let tys = List.map (typ_to_ast ~flags) t_args in
|
||||
Expr.eappop ~op ~args ~tys context_mark
|
||||
| A.EDefault { excepts; just; cons } ->
|
||||
let cons' = typecheck_expr_top_down ~leave_unresolved ctx env tau cons in
|
||||
let cons' = typecheck_expr_top_down ctx env tau cons in
|
||||
let just' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind ~pos:just (TLit TBool))
|
||||
just
|
||||
in
|
||||
let excepts' =
|
||||
List.map (typecheck_expr_top_down ~leave_unresolved ctx env tau) excepts
|
||||
typecheck_expr_top_down ctx env (unionfind ~pos:just (TLit TBool)) just
|
||||
in
|
||||
let excepts' = List.map (typecheck_expr_top_down ctx env tau) excepts in
|
||||
Expr.edefault ~excepts:excepts' ~just:just' ~cons:cons' context_mark
|
||||
| A.EPureDefault e1 ->
|
||||
let inner_ty = unionfind ~pos:e1 (TAny (Any.fresh ())) in
|
||||
let mark =
|
||||
mark_with_tau_and_unify (unionfind ~pos:e1 (TDefault inner_ty))
|
||||
in
|
||||
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env inner_ty e1 in
|
||||
let e1' = typecheck_expr_top_down ctx env inner_ty e1 in
|
||||
Expr.epuredefault e1' mark
|
||||
| A.EIfThenElse { cond; etrue = et; efalse = ef } ->
|
||||
let et' = typecheck_expr_top_down ~leave_unresolved ctx env tau et in
|
||||
let ef' = typecheck_expr_top_down ~leave_unresolved ctx env tau ef in
|
||||
let et' = typecheck_expr_top_down ctx env tau et in
|
||||
let ef' = typecheck_expr_top_down ctx env tau ef in
|
||||
let cond' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind ~pos:cond (TLit TBool))
|
||||
cond
|
||||
typecheck_expr_top_down ctx env (unionfind ~pos:cond (TLit TBool)) cond
|
||||
in
|
||||
Expr.eifthenelse cond' et' ef' context_mark
|
||||
| A.EAssert e1 ->
|
||||
let mark = mark_with_tau_and_unify (unionfind (TLit TUnit)) in
|
||||
let e1' =
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env
|
||||
(unionfind ~pos:e1 (TLit TBool))
|
||||
e1
|
||||
typecheck_expr_top_down ctx env (unionfind ~pos:e1 (TLit TBool)) e1
|
||||
in
|
||||
Expr.eassert e1' mark
|
||||
| A.EEmptyError ->
|
||||
Expr.eemptyerror (ty_mark (TDefault (unionfind (TAny (Any.fresh ())))))
|
||||
| A.EErrorOnEmpty e1 ->
|
||||
let tau' = unionfind (TDefault tau) in
|
||||
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env tau' e1 in
|
||||
let e1' = typecheck_expr_top_down ctx env tau' e1 in
|
||||
Expr.eerroronempty e1' context_mark
|
||||
| A.EArray es ->
|
||||
let cell_type = unionfind (TAny (Any.fresh ())) in
|
||||
let mark = mark_with_tau_and_unify (unionfind (TArray cell_type)) in
|
||||
let es' =
|
||||
List.map (typecheck_expr_top_down ~leave_unresolved ctx env cell_type) es
|
||||
in
|
||||
let es' = List.map (typecheck_expr_top_down ctx env cell_type) es in
|
||||
Expr.earray es' mark
|
||||
| A.ECustom { obj; targs; tret } ->
|
||||
let mark =
|
||||
@ -920,42 +919,36 @@ let wrap_expr ctx f e =
|
||||
|
||||
(** {1 API} *)
|
||||
|
||||
let get_ty_mark ~leave_unresolved (A.Custom { A.custom = uf; pos }) =
|
||||
A.Typed { ty = typ_to_ast ~leave_unresolved uf; pos }
|
||||
let get_ty_mark ~flags (A.Custom { A.custom = uf; pos }) =
|
||||
A.Typed { ty = typ_to_ast ~flags uf; pos }
|
||||
|
||||
let expr_raw
|
||||
(type a)
|
||||
~(leave_unresolved : resolving_strategy)
|
||||
(ctx : A.decl_ctx)
|
||||
?(env = Env.empty ctx)
|
||||
?(typ : A.typ option)
|
||||
(e : (a, 'm) A.gexpr) : (a, unionfind_typ A.custom) A.gexpr =
|
||||
let fty =
|
||||
match typ with
|
||||
| None -> typecheck_expr_bottom_up ~leave_unresolved ctx env
|
||||
| Some typ ->
|
||||
typecheck_expr_top_down ~leave_unresolved ctx env (ast_to_typ typ)
|
||||
| None -> typecheck_expr_bottom_up ctx env
|
||||
| Some typ -> typecheck_expr_top_down ctx env (ast_to_typ typ)
|
||||
in
|
||||
wrap_expr ctx fty e
|
||||
|
||||
let check_expr ~leave_unresolved ctx ?env ?typ e =
|
||||
let check_expr ctx ?env ?typ e =
|
||||
Expr.map_marks
|
||||
~f:(fun (Custom { pos; _ }) -> A.Untyped { pos })
|
||||
(expr_raw ctx ~leave_unresolved ?env ?typ e)
|
||||
(expr_raw ctx ?env ?typ e)
|
||||
|
||||
(* Infer the type of an expression *)
|
||||
let expr ~leave_unresolved ctx ?env ?typ e =
|
||||
Expr.map_marks
|
||||
~f:(get_ty_mark ~leave_unresolved)
|
||||
(expr_raw ~leave_unresolved ctx ?env ?typ e)
|
||||
let expr ctx ?(env = Env.empty ctx) ?typ e =
|
||||
Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) (expr_raw ctx ~env ?typ e)
|
||||
|
||||
let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
|
||||
let rec scope_body_expr ctx env ty_out body_expr =
|
||||
match body_expr with
|
||||
| A.Result e ->
|
||||
let e' =
|
||||
wrap_expr ctx (typecheck_expr_top_down ~leave_unresolved ctx env ty_out) e
|
||||
in
|
||||
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in
|
||||
let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in
|
||||
let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
|
||||
Bindlib.box_apply (fun e -> A.Result e) (Expr.Box.lift e')
|
||||
| A.ScopeLet
|
||||
{
|
||||
@ -966,9 +959,7 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
|
||||
scope_let_pos;
|
||||
} ->
|
||||
let ty_e = ast_to_typ scope_let_typ in
|
||||
let e =
|
||||
wrap_expr ctx (typecheck_expr_bottom_up ~leave_unresolved ctx env) e0
|
||||
in
|
||||
let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in
|
||||
wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e;
|
||||
(* We could use [typecheck_expr_top_down] rather than this manual
|
||||
unification, but we get better messages with this order of the [unify]
|
||||
@ -976,7 +967,7 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
|
||||
inferred. *)
|
||||
let var, next = Bindlib.unbind scope_let_next in
|
||||
let env = Env.add var ty_e env in
|
||||
let next = scope_body_expr ~leave_unresolved ctx env ty_out next in
|
||||
let next = scope_body_expr ctx env ty_out next in
|
||||
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
|
||||
Bindlib.box_apply2
|
||||
(fun scope_let_expr scope_let_next ->
|
||||
@ -985,16 +976,16 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
|
||||
scope_let_kind;
|
||||
scope_let_typ =
|
||||
(match Mark.remove scope_let_typ with
|
||||
| TAny -> typ_to_ast ~leave_unresolved (ty e)
|
||||
| TAny -> typ_to_ast ~flags:env.flags (ty e)
|
||||
| _ -> scope_let_typ);
|
||||
scope_let_expr;
|
||||
scope_let_next;
|
||||
scope_let_pos;
|
||||
})
|
||||
(Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e))
|
||||
(Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e))
|
||||
scope_let_next
|
||||
|
||||
let scope_body ~leave_unresolved ctx env body =
|
||||
let scope_body ctx env body =
|
||||
let get_pos struct_name = Mark.get (A.StructName.get_info struct_name) in
|
||||
let struct_ty struct_name =
|
||||
UnionFind.make (Mark.add (get_pos struct_name) (TStruct struct_name))
|
||||
@ -1003,7 +994,7 @@ let scope_body ~leave_unresolved ctx env body =
|
||||
let ty_out = struct_ty body.A.scope_body_output_struct in
|
||||
let var, e = Bindlib.unbind body.A.scope_body_expr in
|
||||
let env = Env.add var ty_in env in
|
||||
let e' = scope_body_expr ~leave_unresolved ctx env ty_out e in
|
||||
let e' = scope_body_expr ctx env ty_out e in
|
||||
( Bindlib.box_apply
|
||||
(fun scope_body_expr -> { body with scope_body_expr })
|
||||
(Bindlib.bind_var (Var.translate var) e'),
|
||||
@ -1012,35 +1003,33 @@ let scope_body ~leave_unresolved ctx env body =
|
||||
(get_pos body.A.scope_body_output_struct)
|
||||
(TArrow ([ty_in], ty_out))) )
|
||||
|
||||
let rec scopes ~leave_unresolved ctx env = function
|
||||
let rec scopes ctx env = function
|
||||
| A.Nil -> Bindlib.box A.Nil, env
|
||||
| A.Cons (item, next_bind) ->
|
||||
let var, next = Bindlib.unbind next_bind in
|
||||
let env, def =
|
||||
match item with
|
||||
| A.ScopeDef (name, body) ->
|
||||
let body_e, ty_scope = scope_body ~leave_unresolved ctx env body in
|
||||
let body_e, ty_scope = scope_body ctx env body in
|
||||
( Env.add var ty_scope env,
|
||||
Bindlib.box_apply (fun body -> A.ScopeDef (name, body)) body_e )
|
||||
| A.Topdef (name, typ, e) ->
|
||||
let e' = expr_raw ~leave_unresolved ctx ~env ~typ e in
|
||||
let e' = expr_raw ctx ~env ~typ e in
|
||||
let (A.Custom { custom = uf; _ }) = Mark.get e' in
|
||||
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in
|
||||
let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
|
||||
( Env.add var uf env,
|
||||
Bindlib.box_apply
|
||||
(fun e -> A.Topdef (name, Expr.ty e', e))
|
||||
(Expr.Box.lift e') )
|
||||
in
|
||||
let next', env = scopes ~leave_unresolved ctx env next in
|
||||
let next', env = scopes ctx env next in
|
||||
let next_bind' = Bindlib.bind_var (Var.translate var) next' in
|
||||
( Bindlib.box_apply2 (fun item next -> A.Cons (item, next)) def next_bind',
|
||||
env )
|
||||
|
||||
let program ~leave_unresolved prg =
|
||||
let code_items, new_env =
|
||||
scopes ~leave_unresolved prg.A.decl_ctx (Env.empty prg.A.decl_ctx)
|
||||
prg.A.code_items
|
||||
in
|
||||
let program ?fail_on_any ?assume_op_types prg =
|
||||
let env = Env.empty ?fail_on_any ?assume_op_types prg.A.decl_ctx in
|
||||
let code_items, new_env = scopes prg.A.decl_ctx env prg.A.code_items in
|
||||
{
|
||||
A.lang = prg.lang;
|
||||
A.module_name = prg.A.module_name;
|
||||
@ -1055,7 +1044,7 @@ let program ~leave_unresolved prg =
|
||||
(fun f_name (t : A.typ) ->
|
||||
match Mark.remove t with
|
||||
| TAny ->
|
||||
typ_to_ast ~leave_unresolved
|
||||
typ_to_ast ~flags:env.flags
|
||||
(A.StructField.Map.find f_name
|
||||
(A.StructName.Map.find s_name new_env.structs))
|
||||
| _ -> t)
|
||||
@ -1068,7 +1057,7 @@ let program ~leave_unresolved prg =
|
||||
(fun cons_name (t : A.typ) ->
|
||||
match Mark.remove t with
|
||||
| TAny ->
|
||||
typ_to_ast ~leave_unresolved
|
||||
typ_to_ast ~flags:env.flags
|
||||
(A.EnumConstructor.Map.find cons_name
|
||||
(A.EnumName.Map.find e_name new_env.enums))
|
||||
| _ -> t)
|
||||
|
@ -22,7 +22,17 @@ open Definitions
|
||||
module Env : sig
|
||||
type 'e t
|
||||
|
||||
val empty : decl_ctx -> 'e t
|
||||
val empty : ?fail_on_any:bool -> ?assume_op_types:bool -> decl_ctx -> 'e t
|
||||
(** The [~fail_on_any] labeled parameter controls the behavior of the typer in
|
||||
the case where polymorphic expressions are still found after typing: if
|
||||
[false], it allows them (giving them [TAny] and losing typing
|
||||
information); if set to [true] (the default), it aborts.
|
||||
|
||||
The [~assume_op_types] flag (default false) ignores the expected built-in
|
||||
types of polymorphic operators, and will assume correct the type
|
||||
information included in [EAppOp] nodes. This is useful after
|
||||
monomorphisation, which changes the expected types for these operators. *)
|
||||
|
||||
val add_var : 'e Var.t -> typ -> 'e t -> 'e t
|
||||
val add_toplevel_var : TopdefName.t -> typ -> 'e t -> 'e t
|
||||
val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t
|
||||
@ -40,15 +50,7 @@ module Env : sig
|
||||
(** For debug purposes *)
|
||||
end
|
||||
|
||||
(** In the following functions, the [~leave_unresolved] labeled parameter
|
||||
controls the behavior of the typer in the case where polymorphic expressions
|
||||
are still found after typing: if set to [LeaveAny], it allows them (giving
|
||||
them [TAny] and losing typing information); if set to [ErrorOnAny], it
|
||||
aborts. *)
|
||||
type resolving_strategy = LeaveAny | ErrorOnAny
|
||||
|
||||
val expr :
|
||||
leave_unresolved:resolving_strategy ->
|
||||
decl_ctx ->
|
||||
?env:'e Env.t ->
|
||||
?typ:typ ->
|
||||
@ -75,11 +77,10 @@ val expr :
|
||||
application, taking de-tuplification into account.
|
||||
- [TAny] appearing within nodes are refined to more precise types, e.g. on
|
||||
`EAbs` nodes (but be careful with this, it may only work for specific
|
||||
structures of generated code ; [~leave_unresolved:false] checks that it
|
||||
didn't cause problems) *)
|
||||
structures of generated code ; having [~fail_on_any:true] set in the
|
||||
environment (this is the default) checks that it didn't cause problems) *)
|
||||
|
||||
val check_expr :
|
||||
leave_unresolved:resolving_strategy ->
|
||||
decl_ctx ->
|
||||
?env:'e Env.t ->
|
||||
?typ:typ ->
|
||||
@ -91,7 +92,8 @@ val check_expr :
|
||||
information, e.g. any [TAny] appearing in the AST is replaced) *)
|
||||
|
||||
val program :
|
||||
leave_unresolved:resolving_strategy ->
|
||||
?fail_on_any:bool ->
|
||||
?assume_op_types:bool ->
|
||||
('a, 'm) gexpr program ->
|
||||
('a, typed) gexpr program
|
||||
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the
|
||||
|
Loading…
Reference in New Issue
Block a user