typing.ml: fix use of bindlib

This commit is contained in:
Louis Gesbert 2022-06-23 14:04:51 +02:00
parent 67179a793c
commit 6cc2e9a07b
3 changed files with 186 additions and 161 deletions

View File

@ -138,6 +138,18 @@ module Infer = struct
| TArrow (t1, t2) -> TArrow (typ_to_ast t1, typ_to_ast t2), pos
| TAny _ -> TAny, pos
| TArray t1 -> TArray (typ_to_ast t1), pos
let rec ast_to_typ (ty : marked_typ) : unionfind_typ =
let ty' =
match Marked.unmark ty with
| TLit l -> TLit l
| TArrow (t1, t2) -> TArrow (ast_to_typ t1, ast_to_typ t2)
| TTuple (ts, s) -> TTuple (List.map (fun t -> ast_to_typ t) ts, s)
| TEnum (ts, e) -> TEnum (List.map (fun t -> ast_to_typ t) ts, e)
| TArray t -> TArray (ast_to_typ t)
| TAny -> TAny (Any.fresh ())
in
UnionFind.make (Marked.same_mark_as ty' ty)
end
type untyped = { pos : Pos.t } [@@ocaml.unboxed]
@ -483,16 +495,6 @@ let (make_abs : ('m expr, 'm) make_abs_sig) =
fun xs e taus mark ->
Bindlib.box_apply (fun b -> EAbs (b, taus), mark) (Bindlib.bind_mvar xs e)
let empty_thunked_term : untyped marked_expr =
let silent = new_var "_" in
Bindlib.unbox
(make_abs [| silent |]
(Bindlib.box
(ELit LEmptyError, Untyped { pos = Pos.no_pos }
: (untyped expr, untyped) marked))
[TLit TUnit, Pos.no_pos]
(Untyped { pos = Pos.no_pos }))
let make_app :
'm marked_expr Bindlib.box ->
'm marked_expr Bindlib.box list ->
@ -528,6 +530,37 @@ let map_mark2
| Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos }
| Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 }
let fold_marks
(type m)
(pos_f : Pos.t list -> Pos.t)
(ty_f : typed list -> Infer.unionfind_typ)
(ms : m mark list) : m mark =
match ms with
| [] -> invalid_arg "Dcalc.Ast.fold_mark"
| Untyped _ :: _ as ms ->
Untyped { pos = pos_f (List.map (function Untyped { pos } -> pos) ms) }
| Typed _ :: _ ->
Typed
{
pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms);
ty = ty_f (List.map (function Typed m -> m) ms);
}
let empty_thunked_term mark : 'm marked_expr =
let silent = new_var "_" in
let pos = mark_pos mark in
Bindlib.unbox
(make_abs [| silent |]
(Bindlib.box (ELit LEmptyError, mark))
[TLit TUnit, mark_pos mark]
(map_mark
(fun pos -> pos)
(fun ty ->
UnionFind.make
Infer.(
TArrow (UnionFind.make (TLit TUnit, pos), ty), mark_pos mark))
mark))
let (make_let_in : ('m expr, 'm) make_let_in_sig) =
fun x tau e1 e2 pos ->
let m_e1 = Marked.get_mark (Bindlib.unbox e1) in

View File

@ -122,6 +122,8 @@ module Infer: sig
val typ_to_ast : unionfind_typ -> marked_typ
val ast_to_typ : marked_typ -> unionfind_typ
end
type untyped = { pos : Pos.t } [@@unboxed]
@ -232,6 +234,7 @@ val ty: ('a, typed) marked -> typ
val with_ty: Infer.unionfind_typ -> ('a, 'm) marked -> ('a, typed) marked
val map_mark: (Pos.t -> Pos.t) -> (Infer.unionfind_typ -> Infer.unionfind_typ) -> 'm mark -> 'm mark
val map_mark2: (Pos.t -> Pos.t -> Pos.t) -> (typed -> typed -> Infer.unionfind_typ) -> 'm mark -> 'm mark -> 'm mark
val fold_marks: (Pos.t list -> Pos.t) -> (typed list -> Infer.unionfind_typ) -> 'm mark list -> 'm mark
val get_scope_body_mark: ('expr, 'm) scope_body -> 'm mark
(** {2 Boxed constructors} *)
@ -448,7 +451,7 @@ val make_let_in : ('m expr, 'm) make_let_in_sig
(**{2 Other}*)
val empty_thunked_term : untyped marked_expr
val empty_thunked_term : 'm mark -> 'm marked_expr
val is_value : 'm marked_expr -> bool
val equal_exprs : 'm marked_expr -> 'm marked_expr -> bool

View File

@ -185,28 +185,6 @@ let op_type (op : A.operator Marked.pos) : typ Marked.pos UnionFind.elem =
| Binop (Mult KDate) | Binop (Div KDate) | Unop (Minus KDate) ->
Errors.raise_spanned_error pos "This operator is not available!"
let rec ast_to_typ (ty : A.typ) : typ =
match ty with
| A.TLit l -> TLit l
| A.TArrow (t1, t2) ->
TArrow
( UnionFind.make (Marked.map_under_mark ast_to_typ t1),
UnionFind.make (Marked.map_under_mark ast_to_typ t2) )
| A.TTuple (ts, s) ->
TTuple
( List.map
(fun t -> UnionFind.make (Marked.map_under_mark ast_to_typ t))
ts,
s )
| A.TEnum (ts, e) ->
TEnum
( List.map
(fun t -> UnionFind.make (Marked.map_under_mark ast_to_typ t))
ts,
e )
| A.TArray t -> TArray (UnionFind.make (Marked.map_under_mark ast_to_typ t))
| A.TAny -> TAny (Any.fresh ())
(** {1 Double-directed typing} *)
type env = typ Marked.pos UnionFind.elem A.VarMap.t
@ -215,19 +193,34 @@ let add_pos e ty = Marked.mark (A.pos e) ty
let ty (_, A.Typed { ty; _ }) = ty
(** used to convert an [untyped expr var] into a [typed expr var] *)
let translate_var v = Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
let translate_var: 'm1 A.var -> 'm2 A.var =
fun v -> Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
let (let+) x f = Bindlib.box_apply f x
let (and+) x1 x2 = Bindlib.box_pair x1 x2
let bmap f es =
List.fold_right
(fun e acc -> let+ e' = f e and+ acc = acc in e' :: acc)
es
(Bindlib.box [])
let bmap2 f es xs =
List.fold_right2
(fun e x acc ->
let+ e' = f e x
and+ acc = acc
in e' :: acc)
es xs
(Bindlib.box [])
let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
(** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up: 'm .
Ast.decl_ctx ->
env ->
'm A.marked_expr -> A.typed_expr = fun
let rec typecheck_expr_bottom_up
(ctx : Ast.decl_ctx)
(env : env)
(e : 'm A.marked_expr) : A.typed_expr ->
(* (ctx : Ast.decl_ctx)
* (env : env)
* (e : 'm A.marked_expr) : A.typed_expr = *)
(e : 'm A.marked_expr) : A.typed_expr Bindlib.box =
(* Cli.debug_print (Format.asprintf "Looking for type of %a"
(Print.format_expr ctx) e); *)
try
@ -241,29 +234,32 @@ let rec typecheck_expr_bottom_up: 'm .
| A.EVar v -> begin
match A.VarMap.find_opt (A.Var.t v) env with
| Some t ->
mark (EVar (translate_var v)) t
let+ v' = Bindlib.box_var (translate_var v) in
mark v' t
| None ->
Errors.raise_spanned_error (A.pos e)
"Variable not found in the current context"
end
| A.ELit (LBool _) as e1 -> mark_with_uf e1 (TLit TBool)
| A.ELit (LInt _) as e1 -> mark_with_uf e1 (TLit TInt)
| A.ELit (LRat _) as e1 -> mark_with_uf e1 (TLit TRat)
| A.ELit (LMoney _) as e1 -> mark_with_uf e1 (TLit TMoney)
| A.ELit (LDate _) as e1 -> mark_with_uf e1 (TLit TDate)
| A.ELit (LDuration _) as e1 -> mark_with_uf e1 (TLit TDuration)
| A.ELit LUnit as e1 -> mark_with_uf e1 (TLit TUnit)
| A.ELit LEmptyError as e1 -> mark_with_uf e1 (TAny (Any.fresh ()))
| A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool)
| A.ELit (LInt _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TInt)
| A.ELit (LRat _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TRat)
| A.ELit (LMoney _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TMoney)
| A.ELit (LDate _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TDate)
| A.ELit (LDuration _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TDuration)
| A.ELit LUnit as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TUnit)
| A.ELit LEmptyError as e1 -> Bindlib.box @@ mark_with_uf e1 (TAny (Any.fresh ()))
| A.ETuple (es, s) ->
let es = List.map (typecheck_expr_bottom_up ctx env) es in
let+ es =
bmap
(typecheck_expr_bottom_up ctx env)
es
in
mark_with_uf (ETuple (es, s)) (TTuple (List.map ty es, s))
| A.ETupleAccess (e1, n, s, typs) -> begin
let utyps =
List.map
(fun typ -> UnionFind.make (Marked.map_under_mark ast_to_typ typ))
typs
List.map ast_to_typ typs
in
let e1 =
let+ e1 =
typecheck_expr_top_down ctx env e1 (unionfind_make (TTuple (utyps, s)))
in
match List.nth_opt utyps n with
@ -276,8 +272,7 @@ let rec typecheck_expr_bottom_up: 'm .
end
| A.EInj (e1, n, e_name, ts) ->
let ts' =
List.map
(fun t -> UnionFind.make (Marked.map_under_mark ast_to_typ t))
List.map ast_to_typ
ts
in
let ts_n =
@ -289,7 +284,7 @@ let rec typecheck_expr_bottom_up: 'm .
has %d"
n (List.length ts')
in
let e1' = typecheck_expr_top_down ctx env e1 ts_n in
let+ e1' = typecheck_expr_top_down ctx env e1 ts_n in
mark_with_uf (A.EInj (e1', n, e_name, ts)) (TEnum (ts', e_name))
| A.EMatch (e1, es, e_name) ->
let enum_cases =
@ -298,30 +293,28 @@ let rec typecheck_expr_bottom_up: 'm .
es
in
let t_e1 = UnionFind.make (add_pos e1 (TEnum (enum_cases, e_name))) in
let e1' = typecheck_expr_top_down ctx env e1 t_e1 in
let t_ret =
unionfind_make ~pos:e (TAny (Any.fresh ()))
in
let es' = List.map2 (fun es' enum_t ->
let+ e1' = typecheck_expr_top_down ctx env e1 t_e1
and+ es' = bmap2 (fun es' enum_t ->
typecheck_expr_top_down ctx env es'
(unionfind_make ~pos:es' (TArrow (enum_t, t_ret))))
es enum_cases
in
mark (EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, taus) ->
let xs, body = Bindlib.unmbind binder in
if Array.length xs <> List.length taus then
if Bindlib.mbinder_arity binder <> List.length taus then
Errors.raise_spanned_error (A.pos e)
"function has %d variables but was supplied %d types"
(Array.length xs) (List.length taus)
(Bindlib.mbinder_arity binder) (List.length taus)
else
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map translate_var xs in
let xstaus =
List.map2
(fun x tau ->
x,
UnionFind.make
(Marked.map_under_mark ast_to_typ tau))
x, ast_to_typ tau)
(Array.to_list xs) taus
in
let env =
@ -333,29 +326,31 @@ let rec typecheck_expr_bottom_up: 'm .
(fun (_, t_arg) acc ->
unionfind_make (TArrow (t_arg, acc)))
xstaus
(ty body')
(box_ty body')
in
(* TODO: check this use of binders *)
let binder' = Bindlib.unbox (Bindlib.bind_mvar xs' (Bindlib.box body')) in
let+ binder' = Bindlib.bind_mvar xs' body' in
mark (EAbs (binder', taus)) t_func
| A.EApp (e1, args) ->
let args' = List.map (typecheck_expr_bottom_up ctx env) args in
let args' = bmap (typecheck_expr_bottom_up ctx env) args in
let t_ret = unionfind_make (TAny (Any.fresh ())) in
let t_func =
List.fold_right
(fun arg acc -> unionfind_make (TArrow (ty arg, acc)))
args' t_ret
(fun ty_arg acc -> unionfind_make (TArrow (ty_arg, acc)))
(Bindlib.unbox (Bindlib.box_apply (List.map ty) args')) t_ret
in
let e1' = typecheck_expr_top_down ctx env e1 t_func in
let+ e1' = typecheck_expr_top_down ctx env e1 t_func
and+ args' = args' in
mark (EApp (e1', args')) t_ret
| A.EOp op as e1 -> mark e1 (op_type (Marked.mark pos_e op))
| A.EOp op as e1 -> Bindlib.box @@ mark e1 (op_type (Marked.mark pos_e op))
| A.EDefault (excepts, just, cons) ->
let just' = typecheck_expr_top_down ctx env just
(unionfind_make ~pos:just (TLit TBool)) in
let cons' = typecheck_expr_bottom_up ctx env cons in
let tau = ty cons' in
let excepts' =
List.map
let tau = box_ty cons' in
let+ just' = just'
and+ cons' = cons'
and+ excepts' =
bmap
(fun except -> typecheck_expr_top_down ctx env except tau)
excepts
in
@ -364,24 +359,25 @@ let rec typecheck_expr_bottom_up: 'm .
let cond' = typecheck_expr_top_down ctx env cond
(unionfind_make ~pos:cond (TLit TBool)) in
let et' = typecheck_expr_bottom_up ctx env et in
let tau = ty et' in
let ef' = typecheck_expr_top_down ctx env ef tau in
let tau = box_ty et' in
let+ cond' = cond' and+ et' = et'
and+ ef' = typecheck_expr_top_down ctx env ef tau in
mark (A.EIfThenElse (cond', et', ef')) tau
| A.EAssert e1 ->
let e1' =
let+ e1' =
typecheck_expr_top_down ctx env e1
(unionfind_make ~pos:e1 (TLit TBool)) in
mark_with_uf (A.EAssert e1') ~pos:e1 (TLit TUnit)
| A.ErrorOnEmpty e1 ->
let e1' = typecheck_expr_bottom_up ctx env e1 in
let+ e1' = typecheck_expr_bottom_up ctx env e1 in
mark (A.ErrorOnEmpty e1') (ty e1')
| A.EArray es ->
let cell_type = unionfind_make (TAny (Any.fresh ())) in
let es' =
List.map
let+ es' =
bmap
(fun e1 ->
let e1' = typecheck_expr_bottom_up ctx env e1 in
unify ctx cell_type (ty e1');
unify ctx cell_type (box_ty e1');
e1')
es
in
@ -396,14 +392,11 @@ let rec typecheck_expr_bottom_up: 'm .
(** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down
: 'm .
Ast.decl_ctx ->
env ->
'm A.marked_expr -> typ Marked.pos UnionFind.elem -> A.typed_expr = fun (ctx : Ast.decl_ctx)
(ctx : Ast.decl_ctx)
(env : env)
(e : 'm A.marked_expr)
(tau : typ Marked.pos UnionFind.elem) :
A.typed_expr ->
A.typed_expr Bindlib.box =
(* Cli.debug_print (Format.asprintf "Typechecking %a : %a" (Print.format_expr
ctx) e (format_typ ctx) tau); *)
try
@ -416,50 +409,48 @@ and typecheck_expr_top_down
Marked.mark (A.Typed { ty = tau'; pos = pos_e }) e
in
let unionfind_make ?(pos=e) t = UnionFind.make (add_pos pos t) in
let unionfind_of_typ typ =
UnionFind.make (Marked.map_under_mark ast_to_typ typ)
in
match Marked.unmark e with
| A.EVar v -> begin
match A.VarMap.find_opt (A.Var.t v) env with
| Some tau' -> unify_and_mark (A.EVar (translate_var v)) tau'
| Some tau' -> let+ v' = Bindlib.box_var (translate_var v) in
unify_and_mark v' tau'
| None ->
Errors.raise_spanned_error (A.pos e)
"Variable not found in the current context"
end
| A.ELit (LBool _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TBool))
| A.ELit (LInt _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TInt))
| A.ELit (LRat _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TRat))
| A.ELit (LBool _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TBool))
| A.ELit (LInt _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TInt))
| A.ELit (LRat _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TRat))
| A.ELit (LMoney _) as e1 ->
unify_and_mark e1 (unionfind_make (TLit TMoney))
| A.ELit (LDate _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TDate))
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TMoney))
| A.ELit (LDate _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDate))
| A.ELit (LDuration _) as e1 ->
unify_and_mark e1 (unionfind_make (TLit TDuration))
| A.ELit LUnit as e1 -> unify_and_mark e1 (unionfind_make (TLit TUnit))
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDuration))
| A.ELit LUnit as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TUnit))
| A.ELit LEmptyError as e1 ->
unify_and_mark e1 (unionfind_make (TAny (Any.fresh ())))
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TAny (Any.fresh ())))
| A.ETuple (es, s) ->
let es' = List.map (typecheck_expr_bottom_up ctx env) es in
let+ es' = bmap (typecheck_expr_bottom_up ctx env) es in
unify_and_mark (A.ETuple (es', s))
(unionfind_make (TTuple (List.map ty es', s)))
| A.ETupleAccess (e1, n, s, typs) -> begin
let typs' = List.map unionfind_of_typ typs in
let e1' =
let typs' = List.map ast_to_typ typs in
let+ e1' =
typecheck_expr_top_down ctx env e1
(unionfind_make (TTuple (typs', s)))
in
match List.nth_opt typs' n with
| Some t1n ->
unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
| None ->
match List.nth_opt typs' n with
| Some t1n ->
unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
| None ->
Errors.raise_spanned_error (Ast.pos e1)
"Expression should have a tuple type with at least %d elements but \
only has %d"
n (List.length typs)
only has %d"
n (List.length typs)
end
| A.EInj (e1, n, e_name, ts) ->
let ts' =
List.map unionfind_of_typ ts
List.map ast_to_typ ts
in
let ts_n =
match List.nth_opt ts' n with
@ -470,7 +461,7 @@ and typecheck_expr_top_down
has %d"
n (List.length ts)
in
let e1' = typecheck_expr_top_down ctx env e1 ts_n in
let+ e1' = typecheck_expr_top_down ctx env e1 ts_n in
unify_and_mark (A.EInj (e1', n, e_name, ts))
(unionfind_make (TEnum (ts', e_name)))
| A.EMatch (e1, es, e_name) ->
@ -483,46 +474,45 @@ and typecheck_expr_top_down
(unionfind_make ~pos:e1 (TEnum (enum_cases, e_name)))
in
let t_ret = unionfind_make ~pos:e (TAny (Any.fresh ())) in
let es' =
List.map2 (fun es' enum_t ->
let+ e1' = e1'
and+ es' =
bmap2 (fun es' enum_t ->
typecheck_expr_top_down ctx env es'
(unionfind_make ~pos:es' (TArrow (enum_t, t_ret))))
es enum_cases
in
unify_and_mark (EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, t_args) ->
(* Bindlib.box binder |> Bindlib.mbind_apply *)
let xs, body = Bindlib.unmbind binder in
if Array.length xs <> List.length t_args then
if Bindlib.mbinder_arity binder <> List.length t_args then
Errors.raise_spanned_error (A.pos e)
"function has %d variables but was supplied %d types"
(Array.length xs) (List.length t_args)
(Bindlib.mbinder_arity binder) (List.length t_args)
else
let xs' = Array.map translate_var xs in
let xstaus =
List.map2
(fun x t_arg ->
x, UnionFind.make (Marked.map_under_mark ast_to_typ t_arg))
(Array.to_list xs) t_args
in
let env =
List.fold_left
(fun env (x, t_arg) -> A.VarMap.add (A.Var.t x) t_arg env)
env xstaus
in
let body' = typecheck_expr_bottom_up ctx env body in
let t_func =
List.fold_right
(fun (_, t_arg) acc ->
unionfind_make (TArrow (t_arg, acc)))
xstaus (ty body')
in
(* TODO: check this use of binders *)
let binder' = Bindlib.unbox (Bindlib.bind_mvar xs' (Bindlib.box body')) in
unify_and_mark (EAbs (binder', t_args)) t_func
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map translate_var xs in
let xstaus =
List.map2
(fun x t_arg ->
x, ast_to_typ t_arg)
(Array.to_list xs) t_args
in
let env =
List.fold_left
(fun env (x, t_arg) -> A.VarMap.add (A.Var.t x) t_arg env)
env xstaus
in
let body' = typecheck_expr_bottom_up ctx env body in
let t_func =
List.fold_right
(fun (_, t_arg) acc ->
unionfind_make (TArrow (t_arg, acc)))
xstaus (box_ty body')
in
let+ binder' = Bindlib.bind_mvar xs' body' in
unify_and_mark (EAbs (binder', t_args)) t_func
| A.EApp (e1, args) ->
let args' = List.map (typecheck_expr_bottom_up ctx env) args in
let e1' = typecheck_expr_bottom_up ctx env e1 in
let+ args' = bmap (typecheck_expr_bottom_up ctx env) args
and+ e1' = typecheck_expr_bottom_up ctx env e1 in
let t_func =
List.fold_right
(fun arg acc -> unionfind_make (TArrow (ty arg, acc)))
@ -531,35 +521,35 @@ and typecheck_expr_top_down
unify_and_mark (EApp (e1', args')) t_func
| A.EOp op as e1 ->
let op_typ = op_type (add_pos e op) in
unify_and_mark e1 op_typ
Bindlib.box (unify_and_mark e1 op_typ)
| A.EDefault (excepts, just, cons) ->
let just' = typecheck_expr_top_down ctx env just (unionfind_make ~pos:just (TLit TBool)) in
let cons' = typecheck_expr_top_down ctx env cons tau in
let excepts' =
List.map (fun except -> typecheck_expr_top_down ctx env except tau)
let+ just' = typecheck_expr_top_down ctx env just (unionfind_make ~pos:just (TLit TBool))
and+ cons' = typecheck_expr_top_down ctx env cons tau
and+ excepts' =
bmap (fun except -> typecheck_expr_top_down ctx env except tau)
excepts
in
mark (A.EDefault (excepts', just', cons'))
| A.EIfThenElse (cond, et, ef) ->
let cond' = typecheck_expr_top_down ctx env cond
(unionfind_make ~pos:cond (TLit TBool)) in
let et' = typecheck_expr_top_down ctx env et tau in
let ef' = typecheck_expr_top_down ctx env ef tau in
let+ cond' = typecheck_expr_top_down ctx env cond
(unionfind_make ~pos:cond (TLit TBool))
and+ et' = typecheck_expr_top_down ctx env et tau
and+ ef' = typecheck_expr_top_down ctx env ef tau in
mark (A.EIfThenElse (cond', et', ef'))
| A.EAssert e1 ->
let e1' = typecheck_expr_top_down ctx env e1
let+ e1' = typecheck_expr_top_down ctx env e1
(unionfind_make ~pos:e1 (TLit TBool)) in
unify_and_mark (EAssert e1') (unionfind_make ~pos:e1 (TLit TUnit))
| A.ErrorOnEmpty e1 ->
let e1' = typecheck_expr_top_down ctx env e1 tau in
let+ e1' = typecheck_expr_top_down ctx env e1 tau in
mark (A.ErrorOnEmpty e1')
| A.EArray es ->
let cell_type = unionfind_make (TAny (Any.fresh ())) in
let es' =
List.map
let+ es' =
bmap
(fun e1 ->
let e1' = typecheck_expr_bottom_up ctx env e1 in
unify ctx cell_type (ty e1');
unify ctx cell_type (box_ty e1');
e1')
es
in
@ -577,7 +567,7 @@ and typecheck_expr_top_down
(* Infer the type of an expression *)
let infer_types (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : Ast.typed Ast.marked_expr
=
typecheck_expr_bottom_up ctx A.VarMap.empty e
Bindlib.unbox @@ typecheck_expr_bottom_up ctx A.VarMap.empty e
let infer_type (type m) ctx (e: m A.marked_expr) =
match Marked.get_mark e with
@ -591,14 +581,13 @@ let check_type
(tau : A.typ Marked.pos) =
(* todo: consider using the already inferred type if ['m] = [typed] *)
ignore @@
typecheck_expr_top_down ctx A.VarMap.empty e
(UnionFind.make (Marked.map_under_mark ast_to_typ tau))
typecheck_expr_top_down ctx A.VarMap.empty e (ast_to_typ tau)
let infer_types_program prg =
let scopes =
Bindlib.unbox @@
A.map_exprs_in_scopes
~f:(fun e -> Bindlib.box (typecheck_expr_bottom_up prg.A.decl_ctx A.VarMap.empty e))
~f:(typecheck_expr_bottom_up prg.A.decl_ctx A.VarMap.empty)
~varf:translate_var
prg.A.scopes
in