mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
typing.ml: fix use of bindlib
This commit is contained in:
parent
67179a793c
commit
6cc2e9a07b
@ -138,6 +138,18 @@ module Infer = struct
|
||||
| TArrow (t1, t2) -> TArrow (typ_to_ast t1, typ_to_ast t2), pos
|
||||
| TAny _ -> TAny, pos
|
||||
| TArray t1 -> TArray (typ_to_ast t1), pos
|
||||
|
||||
let rec ast_to_typ (ty : marked_typ) : unionfind_typ =
|
||||
let ty' =
|
||||
match Marked.unmark ty with
|
||||
| TLit l -> TLit l
|
||||
| TArrow (t1, t2) -> TArrow (ast_to_typ t1, ast_to_typ t2)
|
||||
| TTuple (ts, s) -> TTuple (List.map (fun t -> ast_to_typ t) ts, s)
|
||||
| TEnum (ts, e) -> TEnum (List.map (fun t -> ast_to_typ t) ts, e)
|
||||
| TArray t -> TArray (ast_to_typ t)
|
||||
| TAny -> TAny (Any.fresh ())
|
||||
in
|
||||
UnionFind.make (Marked.same_mark_as ty' ty)
|
||||
end
|
||||
|
||||
type untyped = { pos : Pos.t } [@@ocaml.unboxed]
|
||||
@ -483,16 +495,6 @@ let (make_abs : ('m expr, 'm) make_abs_sig) =
|
||||
fun xs e taus mark ->
|
||||
Bindlib.box_apply (fun b -> EAbs (b, taus), mark) (Bindlib.bind_mvar xs e)
|
||||
|
||||
let empty_thunked_term : untyped marked_expr =
|
||||
let silent = new_var "_" in
|
||||
Bindlib.unbox
|
||||
(make_abs [| silent |]
|
||||
(Bindlib.box
|
||||
(ELit LEmptyError, Untyped { pos = Pos.no_pos }
|
||||
: (untyped expr, untyped) marked))
|
||||
[TLit TUnit, Pos.no_pos]
|
||||
(Untyped { pos = Pos.no_pos }))
|
||||
|
||||
let make_app :
|
||||
'm marked_expr Bindlib.box ->
|
||||
'm marked_expr Bindlib.box list ->
|
||||
@ -528,6 +530,37 @@ let map_mark2
|
||||
| Untyped m1, Untyped m2 -> Untyped { pos = pos_f m1.pos m2.pos }
|
||||
| Typed m1, Typed m2 -> Typed { pos = pos_f m1.pos m2.pos; ty = ty_f m1 m2 }
|
||||
|
||||
let fold_marks
|
||||
(type m)
|
||||
(pos_f : Pos.t list -> Pos.t)
|
||||
(ty_f : typed list -> Infer.unionfind_typ)
|
||||
(ms : m mark list) : m mark =
|
||||
match ms with
|
||||
| [] -> invalid_arg "Dcalc.Ast.fold_mark"
|
||||
| Untyped _ :: _ as ms ->
|
||||
Untyped { pos = pos_f (List.map (function Untyped { pos } -> pos) ms) }
|
||||
| Typed _ :: _ ->
|
||||
Typed
|
||||
{
|
||||
pos = pos_f (List.map (function Typed { pos; _ } -> pos) ms);
|
||||
ty = ty_f (List.map (function Typed m -> m) ms);
|
||||
}
|
||||
|
||||
let empty_thunked_term mark : 'm marked_expr =
|
||||
let silent = new_var "_" in
|
||||
let pos = mark_pos mark in
|
||||
Bindlib.unbox
|
||||
(make_abs [| silent |]
|
||||
(Bindlib.box (ELit LEmptyError, mark))
|
||||
[TLit TUnit, mark_pos mark]
|
||||
(map_mark
|
||||
(fun pos -> pos)
|
||||
(fun ty ->
|
||||
UnionFind.make
|
||||
Infer.(
|
||||
TArrow (UnionFind.make (TLit TUnit, pos), ty), mark_pos mark))
|
||||
mark))
|
||||
|
||||
let (make_let_in : ('m expr, 'm) make_let_in_sig) =
|
||||
fun x tau e1 e2 pos ->
|
||||
let m_e1 = Marked.get_mark (Bindlib.unbox e1) in
|
||||
|
@ -122,6 +122,8 @@ module Infer: sig
|
||||
|
||||
val typ_to_ast : unionfind_typ -> marked_typ
|
||||
|
||||
val ast_to_typ : marked_typ -> unionfind_typ
|
||||
|
||||
end
|
||||
|
||||
type untyped = { pos : Pos.t } [@@unboxed]
|
||||
@ -232,6 +234,7 @@ val ty: ('a, typed) marked -> typ
|
||||
val with_ty: Infer.unionfind_typ -> ('a, 'm) marked -> ('a, typed) marked
|
||||
val map_mark: (Pos.t -> Pos.t) -> (Infer.unionfind_typ -> Infer.unionfind_typ) -> 'm mark -> 'm mark
|
||||
val map_mark2: (Pos.t -> Pos.t -> Pos.t) -> (typed -> typed -> Infer.unionfind_typ) -> 'm mark -> 'm mark -> 'm mark
|
||||
val fold_marks: (Pos.t list -> Pos.t) -> (typed list -> Infer.unionfind_typ) -> 'm mark list -> 'm mark
|
||||
val get_scope_body_mark: ('expr, 'm) scope_body -> 'm mark
|
||||
|
||||
(** {2 Boxed constructors} *)
|
||||
@ -448,7 +451,7 @@ val make_let_in : ('m expr, 'm) make_let_in_sig
|
||||
|
||||
(**{2 Other}*)
|
||||
|
||||
val empty_thunked_term : untyped marked_expr
|
||||
val empty_thunked_term : 'm mark -> 'm marked_expr
|
||||
val is_value : 'm marked_expr -> bool
|
||||
|
||||
val equal_exprs : 'm marked_expr -> 'm marked_expr -> bool
|
||||
|
@ -185,28 +185,6 @@ let op_type (op : A.operator Marked.pos) : typ Marked.pos UnionFind.elem =
|
||||
| Binop (Mult KDate) | Binop (Div KDate) | Unop (Minus KDate) ->
|
||||
Errors.raise_spanned_error pos "This operator is not available!"
|
||||
|
||||
let rec ast_to_typ (ty : A.typ) : typ =
|
||||
match ty with
|
||||
| A.TLit l -> TLit l
|
||||
| A.TArrow (t1, t2) ->
|
||||
TArrow
|
||||
( UnionFind.make (Marked.map_under_mark ast_to_typ t1),
|
||||
UnionFind.make (Marked.map_under_mark ast_to_typ t2) )
|
||||
| A.TTuple (ts, s) ->
|
||||
TTuple
|
||||
( List.map
|
||||
(fun t -> UnionFind.make (Marked.map_under_mark ast_to_typ t))
|
||||
ts,
|
||||
s )
|
||||
| A.TEnum (ts, e) ->
|
||||
TEnum
|
||||
( List.map
|
||||
(fun t -> UnionFind.make (Marked.map_under_mark ast_to_typ t))
|
||||
ts,
|
||||
e )
|
||||
| A.TArray t -> TArray (UnionFind.make (Marked.map_under_mark ast_to_typ t))
|
||||
| A.TAny -> TAny (Any.fresh ())
|
||||
|
||||
(** {1 Double-directed typing} *)
|
||||
|
||||
type env = typ Marked.pos UnionFind.elem A.VarMap.t
|
||||
@ -215,19 +193,34 @@ let add_pos e ty = Marked.mark (A.pos e) ty
|
||||
let ty (_, A.Typed { ty; _ }) = ty
|
||||
|
||||
(** used to convert an [untyped expr var] into a [typed expr var] *)
|
||||
let translate_var v = Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
|
||||
let translate_var: 'm1 A.var -> 'm2 A.var =
|
||||
fun v -> Bindlib.copy_var v (fun x -> A.EVar x) (Bindlib.name_of v)
|
||||
|
||||
let (let+) x f = Bindlib.box_apply f x
|
||||
let (and+) x1 x2 = Bindlib.box_pair x1 x2
|
||||
|
||||
let bmap f es =
|
||||
List.fold_right
|
||||
(fun e acc -> let+ e' = f e and+ acc = acc in e' :: acc)
|
||||
es
|
||||
(Bindlib.box [])
|
||||
|
||||
let bmap2 f es xs =
|
||||
List.fold_right2
|
||||
(fun e x acc ->
|
||||
let+ e' = f e x
|
||||
and+ acc = acc
|
||||
in e' :: acc)
|
||||
es xs
|
||||
(Bindlib.box [])
|
||||
|
||||
let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
|
||||
|
||||
(** Infers the most permissive type from an expression *)
|
||||
let rec typecheck_expr_bottom_up: 'm .
|
||||
Ast.decl_ctx ->
|
||||
env ->
|
||||
'm A.marked_expr -> A.typed_expr = fun
|
||||
let rec typecheck_expr_bottom_up
|
||||
(ctx : Ast.decl_ctx)
|
||||
(env : env)
|
||||
(e : 'm A.marked_expr) : A.typed_expr ->
|
||||
(* (ctx : Ast.decl_ctx)
|
||||
* (env : env)
|
||||
* (e : 'm A.marked_expr) : A.typed_expr = *)
|
||||
(e : 'm A.marked_expr) : A.typed_expr Bindlib.box =
|
||||
(* Cli.debug_print (Format.asprintf "Looking for type of %a"
|
||||
(Print.format_expr ctx) e); *)
|
||||
try
|
||||
@ -241,29 +234,32 @@ let rec typecheck_expr_bottom_up: 'm .
|
||||
| A.EVar v -> begin
|
||||
match A.VarMap.find_opt (A.Var.t v) env with
|
||||
| Some t ->
|
||||
mark (EVar (translate_var v)) t
|
||||
let+ v' = Bindlib.box_var (translate_var v) in
|
||||
mark v' t
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
"Variable not found in the current context"
|
||||
end
|
||||
| A.ELit (LBool _) as e1 -> mark_with_uf e1 (TLit TBool)
|
||||
| A.ELit (LInt _) as e1 -> mark_with_uf e1 (TLit TInt)
|
||||
| A.ELit (LRat _) as e1 -> mark_with_uf e1 (TLit TRat)
|
||||
| A.ELit (LMoney _) as e1 -> mark_with_uf e1 (TLit TMoney)
|
||||
| A.ELit (LDate _) as e1 -> mark_with_uf e1 (TLit TDate)
|
||||
| A.ELit (LDuration _) as e1 -> mark_with_uf e1 (TLit TDuration)
|
||||
| A.ELit LUnit as e1 -> mark_with_uf e1 (TLit TUnit)
|
||||
| A.ELit LEmptyError as e1 -> mark_with_uf e1 (TAny (Any.fresh ()))
|
||||
| A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool)
|
||||
| A.ELit (LInt _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TInt)
|
||||
| A.ELit (LRat _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TRat)
|
||||
| A.ELit (LMoney _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TMoney)
|
||||
| A.ELit (LDate _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TDate)
|
||||
| A.ELit (LDuration _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TDuration)
|
||||
| A.ELit LUnit as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TUnit)
|
||||
| A.ELit LEmptyError as e1 -> Bindlib.box @@ mark_with_uf e1 (TAny (Any.fresh ()))
|
||||
| A.ETuple (es, s) ->
|
||||
let es = List.map (typecheck_expr_bottom_up ctx env) es in
|
||||
let+ es =
|
||||
bmap
|
||||
(typecheck_expr_bottom_up ctx env)
|
||||
es
|
||||
in
|
||||
mark_with_uf (ETuple (es, s)) (TTuple (List.map ty es, s))
|
||||
| A.ETupleAccess (e1, n, s, typs) -> begin
|
||||
let utyps =
|
||||
List.map
|
||||
(fun typ -> UnionFind.make (Marked.map_under_mark ast_to_typ typ))
|
||||
typs
|
||||
List.map ast_to_typ typs
|
||||
in
|
||||
let e1 =
|
||||
let+ e1 =
|
||||
typecheck_expr_top_down ctx env e1 (unionfind_make (TTuple (utyps, s)))
|
||||
in
|
||||
match List.nth_opt utyps n with
|
||||
@ -276,8 +272,7 @@ let rec typecheck_expr_bottom_up: 'm .
|
||||
end
|
||||
| A.EInj (e1, n, e_name, ts) ->
|
||||
let ts' =
|
||||
List.map
|
||||
(fun t -> UnionFind.make (Marked.map_under_mark ast_to_typ t))
|
||||
List.map ast_to_typ
|
||||
ts
|
||||
in
|
||||
let ts_n =
|
||||
@ -289,7 +284,7 @@ let rec typecheck_expr_bottom_up: 'm .
|
||||
has %d"
|
||||
n (List.length ts')
|
||||
in
|
||||
let e1' = typecheck_expr_top_down ctx env e1 ts_n in
|
||||
let+ e1' = typecheck_expr_top_down ctx env e1 ts_n in
|
||||
mark_with_uf (A.EInj (e1', n, e_name, ts)) (TEnum (ts', e_name))
|
||||
| A.EMatch (e1, es, e_name) ->
|
||||
let enum_cases =
|
||||
@ -298,30 +293,28 @@ let rec typecheck_expr_bottom_up: 'm .
|
||||
es
|
||||
in
|
||||
let t_e1 = UnionFind.make (add_pos e1 (TEnum (enum_cases, e_name))) in
|
||||
let e1' = typecheck_expr_top_down ctx env e1 t_e1 in
|
||||
let t_ret =
|
||||
unionfind_make ~pos:e (TAny (Any.fresh ()))
|
||||
in
|
||||
let es' = List.map2 (fun es' enum_t ->
|
||||
let+ e1' = typecheck_expr_top_down ctx env e1 t_e1
|
||||
and+ es' = bmap2 (fun es' enum_t ->
|
||||
typecheck_expr_top_down ctx env es'
|
||||
(unionfind_make ~pos:es' (TArrow (enum_t, t_ret))))
|
||||
es enum_cases
|
||||
in
|
||||
mark (EMatch (e1', es', e_name)) t_ret
|
||||
| A.EAbs (binder, taus) ->
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
if Array.length xs <> List.length taus then
|
||||
if Bindlib.mbinder_arity binder <> List.length taus then
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
"function has %d variables but was supplied %d types"
|
||||
(Array.length xs) (List.length taus)
|
||||
(Bindlib.mbinder_arity binder) (List.length taus)
|
||||
else
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map translate_var xs in
|
||||
let xstaus =
|
||||
List.map2
|
||||
(fun x tau ->
|
||||
x,
|
||||
UnionFind.make
|
||||
(Marked.map_under_mark ast_to_typ tau))
|
||||
x, ast_to_typ tau)
|
||||
(Array.to_list xs) taus
|
||||
in
|
||||
let env =
|
||||
@ -333,29 +326,31 @@ let rec typecheck_expr_bottom_up: 'm .
|
||||
(fun (_, t_arg) acc ->
|
||||
unionfind_make (TArrow (t_arg, acc)))
|
||||
xstaus
|
||||
(ty body')
|
||||
(box_ty body')
|
||||
in
|
||||
(* TODO: check this use of binders *)
|
||||
let binder' = Bindlib.unbox (Bindlib.bind_mvar xs' (Bindlib.box body')) in
|
||||
let+ binder' = Bindlib.bind_mvar xs' body' in
|
||||
mark (EAbs (binder', taus)) t_func
|
||||
| A.EApp (e1, args) ->
|
||||
let args' = List.map (typecheck_expr_bottom_up ctx env) args in
|
||||
let args' = bmap (typecheck_expr_bottom_up ctx env) args in
|
||||
let t_ret = unionfind_make (TAny (Any.fresh ())) in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun arg acc -> unionfind_make (TArrow (ty arg, acc)))
|
||||
args' t_ret
|
||||
(fun ty_arg acc -> unionfind_make (TArrow (ty_arg, acc)))
|
||||
(Bindlib.unbox (Bindlib.box_apply (List.map ty) args')) t_ret
|
||||
in
|
||||
let e1' = typecheck_expr_top_down ctx env e1 t_func in
|
||||
let+ e1' = typecheck_expr_top_down ctx env e1 t_func
|
||||
and+ args' = args' in
|
||||
mark (EApp (e1', args')) t_ret
|
||||
| A.EOp op as e1 -> mark e1 (op_type (Marked.mark pos_e op))
|
||||
| A.EOp op as e1 -> Bindlib.box @@ mark e1 (op_type (Marked.mark pos_e op))
|
||||
| A.EDefault (excepts, just, cons) ->
|
||||
let just' = typecheck_expr_top_down ctx env just
|
||||
(unionfind_make ~pos:just (TLit TBool)) in
|
||||
let cons' = typecheck_expr_bottom_up ctx env cons in
|
||||
let tau = ty cons' in
|
||||
let excepts' =
|
||||
List.map
|
||||
let tau = box_ty cons' in
|
||||
let+ just' = just'
|
||||
and+ cons' = cons'
|
||||
and+ excepts' =
|
||||
bmap
|
||||
(fun except -> typecheck_expr_top_down ctx env except tau)
|
||||
excepts
|
||||
in
|
||||
@ -364,24 +359,25 @@ let rec typecheck_expr_bottom_up: 'm .
|
||||
let cond' = typecheck_expr_top_down ctx env cond
|
||||
(unionfind_make ~pos:cond (TLit TBool)) in
|
||||
let et' = typecheck_expr_bottom_up ctx env et in
|
||||
let tau = ty et' in
|
||||
let ef' = typecheck_expr_top_down ctx env ef tau in
|
||||
let tau = box_ty et' in
|
||||
let+ cond' = cond' and+ et' = et'
|
||||
and+ ef' = typecheck_expr_top_down ctx env ef tau in
|
||||
mark (A.EIfThenElse (cond', et', ef')) tau
|
||||
| A.EAssert e1 ->
|
||||
let e1' =
|
||||
let+ e1' =
|
||||
typecheck_expr_top_down ctx env e1
|
||||
(unionfind_make ~pos:e1 (TLit TBool)) in
|
||||
mark_with_uf (A.EAssert e1') ~pos:e1 (TLit TUnit)
|
||||
| A.ErrorOnEmpty e1 ->
|
||||
let e1' = typecheck_expr_bottom_up ctx env e1 in
|
||||
let+ e1' = typecheck_expr_bottom_up ctx env e1 in
|
||||
mark (A.ErrorOnEmpty e1') (ty e1')
|
||||
| A.EArray es ->
|
||||
let cell_type = unionfind_make (TAny (Any.fresh ())) in
|
||||
let es' =
|
||||
List.map
|
||||
let+ es' =
|
||||
bmap
|
||||
(fun e1 ->
|
||||
let e1' = typecheck_expr_bottom_up ctx env e1 in
|
||||
unify ctx cell_type (ty e1');
|
||||
unify ctx cell_type (box_ty e1');
|
||||
e1')
|
||||
es
|
||||
in
|
||||
@ -396,14 +392,11 @@ let rec typecheck_expr_bottom_up: 'm .
|
||||
|
||||
(** Checks whether the expression can be typed with the provided type *)
|
||||
and typecheck_expr_top_down
|
||||
: 'm .
|
||||
Ast.decl_ctx ->
|
||||
env ->
|
||||
'm A.marked_expr -> typ Marked.pos UnionFind.elem -> A.typed_expr = fun (ctx : Ast.decl_ctx)
|
||||
(ctx : Ast.decl_ctx)
|
||||
(env : env)
|
||||
(e : 'm A.marked_expr)
|
||||
(tau : typ Marked.pos UnionFind.elem) :
|
||||
A.typed_expr ->
|
||||
A.typed_expr Bindlib.box =
|
||||
(* Cli.debug_print (Format.asprintf "Typechecking %a : %a" (Print.format_expr
|
||||
ctx) e (format_typ ctx) tau); *)
|
||||
try
|
||||
@ -416,50 +409,48 @@ and typecheck_expr_top_down
|
||||
Marked.mark (A.Typed { ty = tau'; pos = pos_e }) e
|
||||
in
|
||||
let unionfind_make ?(pos=e) t = UnionFind.make (add_pos pos t) in
|
||||
let unionfind_of_typ typ =
|
||||
UnionFind.make (Marked.map_under_mark ast_to_typ typ)
|
||||
in
|
||||
match Marked.unmark e with
|
||||
| A.EVar v -> begin
|
||||
match A.VarMap.find_opt (A.Var.t v) env with
|
||||
| Some tau' -> unify_and_mark (A.EVar (translate_var v)) tau'
|
||||
| Some tau' -> let+ v' = Bindlib.box_var (translate_var v) in
|
||||
unify_and_mark v' tau'
|
||||
| None ->
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
"Variable not found in the current context"
|
||||
end
|
||||
| A.ELit (LBool _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TBool))
|
||||
| A.ELit (LInt _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TInt))
|
||||
| A.ELit (LRat _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TRat))
|
||||
| A.ELit (LBool _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TBool))
|
||||
| A.ELit (LInt _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TInt))
|
||||
| A.ELit (LRat _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TRat))
|
||||
| A.ELit (LMoney _) as e1 ->
|
||||
unify_and_mark e1 (unionfind_make (TLit TMoney))
|
||||
| A.ELit (LDate _) as e1 -> unify_and_mark e1 (unionfind_make (TLit TDate))
|
||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TMoney))
|
||||
| A.ELit (LDate _) as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDate))
|
||||
| A.ELit (LDuration _) as e1 ->
|
||||
unify_and_mark e1 (unionfind_make (TLit TDuration))
|
||||
| A.ELit LUnit as e1 -> unify_and_mark e1 (unionfind_make (TLit TUnit))
|
||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDuration))
|
||||
| A.ELit LUnit as e1 -> Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TUnit))
|
||||
| A.ELit LEmptyError as e1 ->
|
||||
unify_and_mark e1 (unionfind_make (TAny (Any.fresh ())))
|
||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TAny (Any.fresh ())))
|
||||
| A.ETuple (es, s) ->
|
||||
let es' = List.map (typecheck_expr_bottom_up ctx env) es in
|
||||
let+ es' = bmap (typecheck_expr_bottom_up ctx env) es in
|
||||
unify_and_mark (A.ETuple (es', s))
|
||||
(unionfind_make (TTuple (List.map ty es', s)))
|
||||
| A.ETupleAccess (e1, n, s, typs) -> begin
|
||||
let typs' = List.map unionfind_of_typ typs in
|
||||
let e1' =
|
||||
let typs' = List.map ast_to_typ typs in
|
||||
let+ e1' =
|
||||
typecheck_expr_top_down ctx env e1
|
||||
(unionfind_make (TTuple (typs', s)))
|
||||
in
|
||||
match List.nth_opt typs' n with
|
||||
| Some t1n ->
|
||||
unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
|
||||
| None ->
|
||||
match List.nth_opt typs' n with
|
||||
| Some t1n ->
|
||||
unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
|
||||
| None ->
|
||||
Errors.raise_spanned_error (Ast.pos e1)
|
||||
"Expression should have a tuple type with at least %d elements but \
|
||||
only has %d"
|
||||
n (List.length typs)
|
||||
only has %d"
|
||||
n (List.length typs)
|
||||
end
|
||||
| A.EInj (e1, n, e_name, ts) ->
|
||||
let ts' =
|
||||
List.map unionfind_of_typ ts
|
||||
List.map ast_to_typ ts
|
||||
in
|
||||
let ts_n =
|
||||
match List.nth_opt ts' n with
|
||||
@ -470,7 +461,7 @@ and typecheck_expr_top_down
|
||||
has %d"
|
||||
n (List.length ts)
|
||||
in
|
||||
let e1' = typecheck_expr_top_down ctx env e1 ts_n in
|
||||
let+ e1' = typecheck_expr_top_down ctx env e1 ts_n in
|
||||
unify_and_mark (A.EInj (e1', n, e_name, ts))
|
||||
(unionfind_make (TEnum (ts', e_name)))
|
||||
| A.EMatch (e1, es, e_name) ->
|
||||
@ -483,46 +474,45 @@ and typecheck_expr_top_down
|
||||
(unionfind_make ~pos:e1 (TEnum (enum_cases, e_name)))
|
||||
in
|
||||
let t_ret = unionfind_make ~pos:e (TAny (Any.fresh ())) in
|
||||
let es' =
|
||||
List.map2 (fun es' enum_t ->
|
||||
let+ e1' = e1'
|
||||
and+ es' =
|
||||
bmap2 (fun es' enum_t ->
|
||||
typecheck_expr_top_down ctx env es'
|
||||
(unionfind_make ~pos:es' (TArrow (enum_t, t_ret))))
|
||||
es enum_cases
|
||||
in
|
||||
unify_and_mark (EMatch (e1', es', e_name)) t_ret
|
||||
| A.EAbs (binder, t_args) ->
|
||||
(* Bindlib.box binder |> Bindlib.mbind_apply *)
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
if Array.length xs <> List.length t_args then
|
||||
if Bindlib.mbinder_arity binder <> List.length t_args then
|
||||
Errors.raise_spanned_error (A.pos e)
|
||||
"function has %d variables but was supplied %d types"
|
||||
(Array.length xs) (List.length t_args)
|
||||
(Bindlib.mbinder_arity binder) (List.length t_args)
|
||||
else
|
||||
let xs' = Array.map translate_var xs in
|
||||
let xstaus =
|
||||
List.map2
|
||||
(fun x t_arg ->
|
||||
x, UnionFind.make (Marked.map_under_mark ast_to_typ t_arg))
|
||||
(Array.to_list xs) t_args
|
||||
in
|
||||
let env =
|
||||
List.fold_left
|
||||
(fun env (x, t_arg) -> A.VarMap.add (A.Var.t x) t_arg env)
|
||||
env xstaus
|
||||
in
|
||||
let body' = typecheck_expr_bottom_up ctx env body in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun (_, t_arg) acc ->
|
||||
unionfind_make (TArrow (t_arg, acc)))
|
||||
xstaus (ty body')
|
||||
in
|
||||
(* TODO: check this use of binders *)
|
||||
let binder' = Bindlib.unbox (Bindlib.bind_mvar xs' (Bindlib.box body')) in
|
||||
unify_and_mark (EAbs (binder', t_args)) t_func
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map translate_var xs in
|
||||
let xstaus =
|
||||
List.map2
|
||||
(fun x t_arg ->
|
||||
x, ast_to_typ t_arg)
|
||||
(Array.to_list xs) t_args
|
||||
in
|
||||
let env =
|
||||
List.fold_left
|
||||
(fun env (x, t_arg) -> A.VarMap.add (A.Var.t x) t_arg env)
|
||||
env xstaus
|
||||
in
|
||||
let body' = typecheck_expr_bottom_up ctx env body in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun (_, t_arg) acc ->
|
||||
unionfind_make (TArrow (t_arg, acc)))
|
||||
xstaus (box_ty body')
|
||||
in
|
||||
let+ binder' = Bindlib.bind_mvar xs' body' in
|
||||
unify_and_mark (EAbs (binder', t_args)) t_func
|
||||
| A.EApp (e1, args) ->
|
||||
let args' = List.map (typecheck_expr_bottom_up ctx env) args in
|
||||
let e1' = typecheck_expr_bottom_up ctx env e1 in
|
||||
let+ args' = bmap (typecheck_expr_bottom_up ctx env) args
|
||||
and+ e1' = typecheck_expr_bottom_up ctx env e1 in
|
||||
let t_func =
|
||||
List.fold_right
|
||||
(fun arg acc -> unionfind_make (TArrow (ty arg, acc)))
|
||||
@ -531,35 +521,35 @@ and typecheck_expr_top_down
|
||||
unify_and_mark (EApp (e1', args')) t_func
|
||||
| A.EOp op as e1 ->
|
||||
let op_typ = op_type (add_pos e op) in
|
||||
unify_and_mark e1 op_typ
|
||||
Bindlib.box (unify_and_mark e1 op_typ)
|
||||
| A.EDefault (excepts, just, cons) ->
|
||||
let just' = typecheck_expr_top_down ctx env just (unionfind_make ~pos:just (TLit TBool)) in
|
||||
let cons' = typecheck_expr_top_down ctx env cons tau in
|
||||
let excepts' =
|
||||
List.map (fun except -> typecheck_expr_top_down ctx env except tau)
|
||||
let+ just' = typecheck_expr_top_down ctx env just (unionfind_make ~pos:just (TLit TBool))
|
||||
and+ cons' = typecheck_expr_top_down ctx env cons tau
|
||||
and+ excepts' =
|
||||
bmap (fun except -> typecheck_expr_top_down ctx env except tau)
|
||||
excepts
|
||||
in
|
||||
mark (A.EDefault (excepts', just', cons'))
|
||||
| A.EIfThenElse (cond, et, ef) ->
|
||||
let cond' = typecheck_expr_top_down ctx env cond
|
||||
(unionfind_make ~pos:cond (TLit TBool)) in
|
||||
let et' = typecheck_expr_top_down ctx env et tau in
|
||||
let ef' = typecheck_expr_top_down ctx env ef tau in
|
||||
let+ cond' = typecheck_expr_top_down ctx env cond
|
||||
(unionfind_make ~pos:cond (TLit TBool))
|
||||
and+ et' = typecheck_expr_top_down ctx env et tau
|
||||
and+ ef' = typecheck_expr_top_down ctx env ef tau in
|
||||
mark (A.EIfThenElse (cond', et', ef'))
|
||||
| A.EAssert e1 ->
|
||||
let e1' = typecheck_expr_top_down ctx env e1
|
||||
let+ e1' = typecheck_expr_top_down ctx env e1
|
||||
(unionfind_make ~pos:e1 (TLit TBool)) in
|
||||
unify_and_mark (EAssert e1') (unionfind_make ~pos:e1 (TLit TUnit))
|
||||
| A.ErrorOnEmpty e1 ->
|
||||
let e1' = typecheck_expr_top_down ctx env e1 tau in
|
||||
let+ e1' = typecheck_expr_top_down ctx env e1 tau in
|
||||
mark (A.ErrorOnEmpty e1')
|
||||
| A.EArray es ->
|
||||
let cell_type = unionfind_make (TAny (Any.fresh ())) in
|
||||
let es' =
|
||||
List.map
|
||||
let+ es' =
|
||||
bmap
|
||||
(fun e1 ->
|
||||
let e1' = typecheck_expr_bottom_up ctx env e1 in
|
||||
unify ctx cell_type (ty e1');
|
||||
unify ctx cell_type (box_ty e1');
|
||||
e1')
|
||||
es
|
||||
in
|
||||
@ -577,7 +567,7 @@ and typecheck_expr_top_down
|
||||
(* Infer the type of an expression *)
|
||||
let infer_types (ctx : Ast.decl_ctx) (e : 'm A.marked_expr) : Ast.typed Ast.marked_expr
|
||||
=
|
||||
typecheck_expr_bottom_up ctx A.VarMap.empty e
|
||||
Bindlib.unbox @@ typecheck_expr_bottom_up ctx A.VarMap.empty e
|
||||
|
||||
let infer_type (type m) ctx (e: m A.marked_expr) =
|
||||
match Marked.get_mark e with
|
||||
@ -591,14 +581,13 @@ let check_type
|
||||
(tau : A.typ Marked.pos) =
|
||||
(* todo: consider using the already inferred type if ['m] = [typed] *)
|
||||
ignore @@
|
||||
typecheck_expr_top_down ctx A.VarMap.empty e
|
||||
(UnionFind.make (Marked.map_under_mark ast_to_typ tau))
|
||||
typecheck_expr_top_down ctx A.VarMap.empty e (ast_to_typ tau)
|
||||
|
||||
let infer_types_program prg =
|
||||
let scopes =
|
||||
Bindlib.unbox @@
|
||||
A.map_exprs_in_scopes
|
||||
~f:(fun e -> Bindlib.box (typecheck_expr_bottom_up prg.A.decl_ctx A.VarMap.empty e))
|
||||
~f:(typecheck_expr_bottom_up prg.A.decl_ctx A.VarMap.empty)
|
||||
~varf:translate_var
|
||||
prg.A.scopes
|
||||
in
|
||||
|
Loading…
Reference in New Issue
Block a user