mirror of
https://github.com/CatalaLang/catala.git
synced 2024-09-20 00:41:05 +03:00
Generalise the typer
This moves dcalc/typing.ml to shared_ast, and generalises the input type, but without yet implementing the extra cases (these are all `assert false`): it's just a first step.
This commit is contained in:
parent
0bb9cce341
commit
b37a6c3703
@ -219,7 +219,7 @@ let driver source_file (options : Cli.options) : int =
|
|||||||
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc
|
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc
|
||||||
| `Proof | `Plugin _ ) as backend -> (
|
| `Proof | `Plugin _ ) as backend -> (
|
||||||
Cli.debug_print "Typechecking...";
|
Cli.debug_print "Typechecking...";
|
||||||
let prgm = Dcalc.Typing.infer_types_program prgm in
|
let prgm = Shared_ast.Typing.infer_types_program prgm in
|
||||||
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
|
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
|
||||||
(Print.typ prgm.decl_ctx) typ); *)
|
(Print.typ prgm.decl_ctx) typ); *)
|
||||||
match backend with
|
match backend with
|
||||||
|
@ -264,7 +264,7 @@ type typed = { pos : Pos.t; ty : typ }
|
|||||||
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
|
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
|
||||||
|
|
||||||
(** Useful for errors and printing, for example *)
|
(** Useful for errors and printing, for example *)
|
||||||
type any_expr = AnyExpr : (_ any, _ mark) gexpr -> any_expr
|
type any_expr = AnyExpr : (_, _ mark) gexpr -> any_expr
|
||||||
|
|
||||||
(** {2 Higher-level program structure} *)
|
(** {2 Higher-level program structure} *)
|
||||||
|
|
||||||
|
@ -20,3 +20,4 @@ module Expr = Expr
|
|||||||
module Scope = Scope
|
module Scope = Scope
|
||||||
module Program = Program
|
module Program = Program
|
||||||
module Print = Print
|
module Print = Print
|
||||||
|
module Typing = Typing
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
inference using the classical W algorithm with union-find unification. *)
|
inference using the classical W algorithm with union-find unification. *)
|
||||||
|
|
||||||
open Utils
|
open Utils
|
||||||
module A = Shared_ast
|
module A = Definitions
|
||||||
|
|
||||||
module Any =
|
module Any =
|
||||||
Utils.Uid.Make
|
Utils.Uid.Make
|
||||||
@ -33,8 +33,8 @@ module Any =
|
|||||||
()
|
()
|
||||||
|
|
||||||
type unionfind_typ = naked_typ Marked.pos UnionFind.elem
|
type unionfind_typ = naked_typ Marked.pos UnionFind.elem
|
||||||
(** We do not reuse {!type: Dcalc.Ast.naked_typ} because we have to include a
|
(** We do not reuse {!type: Shared_ast.typ} because we have to include a new
|
||||||
new [TAny] variant. Indeed, error terms can have any type and this has to be
|
[TAny] variant. Indeed, error terms can have any type and this has to be
|
||||||
captured by the type sytem. *)
|
captured by the type sytem. *)
|
||||||
|
|
||||||
and naked_typ =
|
and naked_typ =
|
||||||
@ -90,7 +90,7 @@ let rec format_typ
|
|||||||
in
|
in
|
||||||
let naked_typ = UnionFind.get (UnionFind.find naked_typ) in
|
let naked_typ = UnionFind.get (UnionFind.find naked_typ) in
|
||||||
match Marked.unmark naked_typ with
|
match Marked.unmark naked_typ with
|
||||||
| TLit l -> Format.fprintf fmt "%a" A.Print.tlit l
|
| TLit l -> Format.fprintf fmt "%a" Print.tlit l
|
||||||
| TTuple ts ->
|
| TTuple ts ->
|
||||||
Format.fprintf fmt "@[<hov 2>(%a)]"
|
Format.fprintf fmt "@[<hov 2>(%a)]"
|
||||||
(Format.pp_print_list
|
(Format.pp_print_list
|
||||||
@ -195,6 +195,17 @@ let handle_type_error ctx e t1 t2 =
|
|||||||
(Cli.format_with_style [ANSITerminal.blue; ANSITerminal.Bold])
|
(Cli.format_with_style [ANSITerminal.blue; ANSITerminal.Bold])
|
||||||
"-->" t2_s ()
|
"-->" t2_s ()
|
||||||
|
|
||||||
|
let lit_type (type a) (lit : a A.glit) : naked_typ =
|
||||||
|
match lit with
|
||||||
|
| LBool _ -> TLit TBool
|
||||||
|
| LInt _ -> TLit TInt
|
||||||
|
| LRat _ -> TLit TRat
|
||||||
|
| LMoney _ -> TLit TMoney
|
||||||
|
| LDate _ -> TLit TDate
|
||||||
|
| LDuration _ -> TLit TDuration
|
||||||
|
| LUnit -> TLit TUnit
|
||||||
|
| LEmptyError -> TAny (Any.fresh ())
|
||||||
|
|
||||||
(** Operators have a single type, instead of being polymorphic with constraints.
|
(** Operators have a single type, instead of being polymorphic with constraints.
|
||||||
This allows us to have a simpler type system, while we argue the syntactic
|
This allows us to have a simpler type system, while we argue the syntactic
|
||||||
burden of operator annotations helps the programmer visualize the type flow
|
burden of operator annotations helps the programmer visualize the type flow
|
||||||
@ -265,9 +276,9 @@ let op_type (op : A.operator Marked.pos) : unionfind_typ =
|
|||||||
|
|
||||||
(** {1 Double-directed typing} *)
|
(** {1 Double-directed typing} *)
|
||||||
|
|
||||||
type 'e env = ('e, unionfind_typ) A.Var.Map.t
|
type 'e env = ('e, unionfind_typ) Var.Map.t
|
||||||
|
|
||||||
let add_pos e ty = Marked.mark (A.Expr.pos e) ty
|
let add_pos e ty = Marked.mark (Expr.pos e) ty
|
||||||
let ty (_, { uf; _ }) = uf
|
let ty (_, { uf; _ }) = uf
|
||||||
let ( let+ ) x f = Bindlib.box_apply f x
|
let ( let+ ) x f = Bindlib.box_apply f x
|
||||||
let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2
|
let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2
|
||||||
@ -293,41 +304,40 @@ let bmap2 (f : 'a -> 'b -> 'c Bindlib.box) (es : 'a list) (xs : 'b list) :
|
|||||||
let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
|
let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
|
||||||
|
|
||||||
(** Infers the most permissive type from an expression *)
|
(** Infers the most permissive type from an expression *)
|
||||||
let rec typecheck_expr_bottom_up
|
let rec typecheck_expr_bottom_up :
|
||||||
(ctx : A.decl_ctx)
|
type a.
|
||||||
(env : 'm Ast.expr env)
|
A.decl_ctx ->
|
||||||
(e : 'm Ast.expr) : (A.dcalc, mark) A.gexpr Bindlib.box =
|
(a, 'm A.mark) A.gexpr env ->
|
||||||
|
(a, 'm A.mark) A.gexpr ->
|
||||||
|
(a, mark) A.gexpr A.box =
|
||||||
|
fun ctx env e ->
|
||||||
(* Cli.debug_format "Looking for type of %a" (Expr.format ~debug:true ctx)
|
(* Cli.debug_format "Looking for type of %a" (Expr.format ~debug:true ctx)
|
||||||
e; *)
|
e; *)
|
||||||
let pos_e = A.Expr.pos e in
|
let pos_e = Expr.pos e in
|
||||||
let mark (e : (A.dcalc, mark) A.naked_gexpr) uf =
|
let mark e uf = Marked.mark { uf; pos = pos_e } e in
|
||||||
Marked.mark { uf; pos = pos_e } e
|
|
||||||
in
|
|
||||||
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
||||||
let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in
|
let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in
|
||||||
match Marked.unmark e with
|
match Marked.unmark e with
|
||||||
|
| A.ELocation _ -> assert false
|
||||||
|
| A.EStruct _ -> assert false
|
||||||
|
| A.EStructAccess _ -> assert false
|
||||||
|
| A.EEnumInj _ -> assert false
|
||||||
|
| A.EMatchS _ -> assert false
|
||||||
|
| A.ERaise _ -> assert false
|
||||||
|
| A.ECatch _ -> assert false
|
||||||
| A.EVar v -> begin
|
| A.EVar v -> begin
|
||||||
match A.Var.Map.find_opt v env with
|
match Var.Map.find_opt v env with
|
||||||
| Some t ->
|
| Some t ->
|
||||||
let+ v' = Bindlib.box_var (A.Var.translate v) in
|
let+ v' = Bindlib.box_var (Var.translate v) in
|
||||||
mark v' t
|
mark v' t
|
||||||
| None ->
|
| None ->
|
||||||
Errors.raise_spanned_error (A.Expr.pos e)
|
Errors.raise_spanned_error (Expr.pos e)
|
||||||
"Variable %s not found in the current context." (Bindlib.name_of v)
|
"Variable %s not found in the current context." (Bindlib.name_of v)
|
||||||
end
|
end
|
||||||
| A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool)
|
| A.ELit lit as e1 -> Bindlib.box @@ mark_with_uf e1 (lit_type lit)
|
||||||
| A.ELit (LInt _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TInt)
|
|
||||||
| A.ELit (LRat _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TRat)
|
|
||||||
| A.ELit (LMoney _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TMoney)
|
|
||||||
| A.ELit (LDate _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TDate)
|
|
||||||
| A.ELit (LDuration _) as e1 ->
|
|
||||||
Bindlib.box @@ mark_with_uf e1 (TLit TDuration)
|
|
||||||
| A.ELit LUnit as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TUnit)
|
|
||||||
| A.ELit LEmptyError as e1 ->
|
|
||||||
Bindlib.box @@ mark_with_uf e1 (TAny (Any.fresh ()))
|
|
||||||
| A.ETuple (es, None) ->
|
| A.ETuple (es, None) ->
|
||||||
let+ es = bmap (typecheck_expr_bottom_up ctx env) es in
|
let+ es = bmap (typecheck_expr_bottom_up ctx env) es in
|
||||||
mark_with_uf (ETuple (es, None)) (TTuple (List.map ty es))
|
mark_with_uf (A.ETuple (es, None)) (TTuple (List.map ty es))
|
||||||
| A.ETuple (es, Some s_name) ->
|
| A.ETuple (es, Some s_name) ->
|
||||||
let tys =
|
let tys =
|
||||||
List.map
|
List.map
|
||||||
@ -335,13 +345,13 @@ let rec typecheck_expr_bottom_up
|
|||||||
(A.StructMap.find s_name ctx.A.ctx_structs)
|
(A.StructMap.find s_name ctx.A.ctx_structs)
|
||||||
in
|
in
|
||||||
let+ es = bmap2 (typecheck_expr_top_down ctx env) tys es in
|
let+ es = bmap2 (typecheck_expr_top_down ctx env) tys es in
|
||||||
mark_with_uf (ETuple (es, Some s_name)) (TStruct s_name)
|
mark_with_uf (A.ETuple (es, Some s_name)) (TStruct s_name)
|
||||||
| A.ETupleAccess (e1, n, s, typs) -> begin
|
| A.ETupleAccess (e1, n, s, typs) -> begin
|
||||||
let utyps = List.map ast_to_typ typs in
|
let utyps = List.map ast_to_typ typs in
|
||||||
let tuple_ty = match s with None -> TTuple utyps | Some s -> TStruct s in
|
let tuple_ty = match s with None -> TTuple utyps | Some s -> TStruct s in
|
||||||
let+ e1 = typecheck_expr_top_down ctx env (unionfind_make tuple_ty) e1 in
|
let+ e1 = typecheck_expr_top_down ctx env (unionfind_make tuple_ty) e1 in
|
||||||
match List.nth_opt utyps n with
|
match List.nth_opt utyps n with
|
||||||
| Some t' -> mark (ETupleAccess (e1, n, s, typs)) t'
|
| Some t' -> mark (A.ETupleAccess (e1, n, s, typs)) t'
|
||||||
| None ->
|
| None ->
|
||||||
Errors.raise_spanned_error (Marked.get_mark e1).pos
|
Errors.raise_spanned_error (Marked.get_mark e1).pos
|
||||||
"Expression should have a tuple type with at least %d elements but \
|
"Expression should have a tuple type with at least %d elements but \
|
||||||
@ -354,7 +364,7 @@ let rec typecheck_expr_bottom_up
|
|||||||
match List.nth_opt ts' n with
|
match List.nth_opt ts' n with
|
||||||
| Some ts_n -> ts_n
|
| Some ts_n -> ts_n
|
||||||
| None ->
|
| None ->
|
||||||
Errors.raise_spanned_error (A.Expr.pos e)
|
Errors.raise_spanned_error (Expr.pos e)
|
||||||
"Expression should have a sum type with at least %d cases but only \
|
"Expression should have a sum type with at least %d cases but only \
|
||||||
has %d"
|
has %d"
|
||||||
n (List.length ts')
|
n (List.length ts')
|
||||||
@ -376,19 +386,19 @@ let rec typecheck_expr_bottom_up
|
|||||||
es')
|
es')
|
||||||
es enum_cases
|
es enum_cases
|
||||||
in
|
in
|
||||||
mark (EMatch (e1', es', e_name)) t_ret
|
mark (A.EMatch (e1', es', e_name)) t_ret
|
||||||
| A.EAbs (binder, taus) ->
|
| A.EAbs (binder, taus) ->
|
||||||
if Bindlib.mbinder_arity binder <> List.length taus then
|
if Bindlib.mbinder_arity binder <> List.length taus then
|
||||||
Errors.raise_spanned_error (A.Expr.pos e)
|
Errors.raise_spanned_error (Expr.pos e)
|
||||||
"function has %d variables but was supplied %d types"
|
"function has %d variables but was supplied %d types"
|
||||||
(Bindlib.mbinder_arity binder)
|
(Bindlib.mbinder_arity binder)
|
||||||
(List.length taus)
|
(List.length taus)
|
||||||
else
|
else
|
||||||
let xs, body = Bindlib.unmbind binder in
|
let xs, body = Bindlib.unmbind binder in
|
||||||
let xs' = Array.map A.Var.translate xs in
|
let xs' = Array.map Var.translate xs in
|
||||||
let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in
|
let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in
|
||||||
let env =
|
let env =
|
||||||
List.fold_left (fun env (x, tau) -> A.Var.Map.add x tau env) env xstaus
|
List.fold_left (fun env (x, tau) -> Var.Map.add x tau env) env xstaus
|
||||||
in
|
in
|
||||||
let body' = typecheck_expr_bottom_up ctx env body in
|
let body' = typecheck_expr_bottom_up ctx env body in
|
||||||
let t_func =
|
let t_func =
|
||||||
@ -397,7 +407,7 @@ let rec typecheck_expr_bottom_up
|
|||||||
xstaus (box_ty body')
|
xstaus (box_ty body')
|
||||||
in
|
in
|
||||||
let+ binder' = Bindlib.bind_mvar xs' body' in
|
let+ binder' = Bindlib.bind_mvar xs' body' in
|
||||||
mark (EAbs (binder', taus)) t_func
|
mark (A.EAbs (binder', taus)) t_func
|
||||||
| A.EApp (e1, args) ->
|
| A.EApp (e1, args) ->
|
||||||
let args' = bmap (typecheck_expr_bottom_up ctx env) args in
|
let args' = bmap (typecheck_expr_bottom_up ctx env) args in
|
||||||
let t_ret = unionfind_make (TAny (Any.fresh ())) in
|
let t_ret = unionfind_make (TAny (Any.fresh ())) in
|
||||||
@ -409,7 +419,7 @@ let rec typecheck_expr_bottom_up
|
|||||||
in
|
in
|
||||||
let+ e1' = typecheck_expr_bottom_up ctx env e1 and+ args' in
|
let+ e1' = typecheck_expr_bottom_up ctx env e1 and+ args' in
|
||||||
unify ctx e (ty e1') t_func;
|
unify ctx e (ty e1') t_func;
|
||||||
mark (EApp (e1', args')) t_ret
|
mark (A.EApp (e1', args')) t_ret
|
||||||
| A.EOp op as e1 -> Bindlib.box @@ mark e1 (op_type (Marked.mark pos_e op))
|
| A.EOp op as e1 -> Bindlib.box @@ mark e1 (op_type (Marked.mark pos_e op))
|
||||||
| A.EDefault (excepts, just, cons) ->
|
| A.EDefault (excepts, just, cons) ->
|
||||||
let just' =
|
let just' =
|
||||||
@ -456,46 +466,42 @@ let rec typecheck_expr_bottom_up
|
|||||||
mark_with_uf (A.EArray es') (TArray cell_type)
|
mark_with_uf (A.EArray es') (TArray cell_type)
|
||||||
|
|
||||||
(** Checks whether the expression can be typed with the provided type *)
|
(** Checks whether the expression can be typed with the provided type *)
|
||||||
and typecheck_expr_top_down
|
and typecheck_expr_top_down :
|
||||||
(ctx : A.decl_ctx)
|
type a.
|
||||||
(env : 'm Ast.expr env)
|
A.decl_ctx ->
|
||||||
(tau : unionfind_typ)
|
(a, 'm A.mark) A.gexpr env ->
|
||||||
(e : 'm Ast.expr) : (A.dcalc, mark) A.gexpr Bindlib.box =
|
unionfind_typ ->
|
||||||
|
(a, 'm A.mark) A.gexpr ->
|
||||||
|
(a, mark) A.gexpr Bindlib.box =
|
||||||
|
fun ctx env tau e ->
|
||||||
(* Cli.debug_format "Propagating type %a for naked_expr %a" (format_typ ctx)
|
(* Cli.debug_format "Propagating type %a for naked_expr %a" (format_typ ctx)
|
||||||
tau (Expr.format ctx) e; *)
|
tau (Expr.format ctx) e; *)
|
||||||
let pos_e = A.Expr.pos e in
|
let pos_e = Expr.pos e in
|
||||||
let mark e = Marked.mark { uf = tau; pos = pos_e } e in
|
let mark e = Marked.mark { uf = tau; pos = pos_e } e in
|
||||||
let unify_and_mark (e' : (A.dcalc, mark) A.naked_gexpr) tau' =
|
let unify_and_mark (e' : (a, mark) A.naked_gexpr) tau' =
|
||||||
unify ctx e tau' tau;
|
unify ctx e tau' tau;
|
||||||
Marked.mark { uf = tau; pos = pos_e } e'
|
Marked.mark { uf = tau; pos = pos_e } e'
|
||||||
in
|
in
|
||||||
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
||||||
match Marked.unmark e with
|
match Marked.unmark e with
|
||||||
|
| A.ELocation _ -> assert false
|
||||||
|
| A.EStruct _ -> assert false
|
||||||
|
| A.EStructAccess _ -> assert false
|
||||||
|
| A.EEnumInj _ -> assert false
|
||||||
|
| A.EMatchS _ -> assert false
|
||||||
|
| A.ERaise _ -> assert false
|
||||||
|
| A.ECatch _ -> assert false
|
||||||
| A.EVar v -> begin
|
| A.EVar v -> begin
|
||||||
match A.Var.Map.find_opt v env with
|
match Var.Map.find_opt v env with
|
||||||
| Some tau' ->
|
| Some tau' ->
|
||||||
let+ v' = Bindlib.box_var (A.Var.translate v) in
|
let+ v' = Bindlib.box_var (Var.translate v) in
|
||||||
unify_and_mark v' tau'
|
unify_and_mark v' tau'
|
||||||
| None ->
|
| None ->
|
||||||
Errors.raise_spanned_error pos_e
|
Errors.raise_spanned_error pos_e
|
||||||
"Variable %s not found in the current context" (Bindlib.name_of v)
|
"Variable %s not found in the current context" (Bindlib.name_of v)
|
||||||
end
|
end
|
||||||
| A.ELit (LBool _) as e1 ->
|
| A.ELit lit as e1 ->
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TBool))
|
Bindlib.box @@ unify_and_mark e1 (unionfind_make (lit_type lit))
|
||||||
| A.ELit (LInt _) as e1 ->
|
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TInt))
|
|
||||||
| A.ELit (LRat _) as e1 ->
|
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TRat))
|
|
||||||
| A.ELit (LMoney _) as e1 ->
|
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TMoney))
|
|
||||||
| A.ELit (LDate _) as e1 ->
|
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDate))
|
|
||||||
| A.ELit (LDuration _) as e1 ->
|
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDuration))
|
|
||||||
| A.ELit LUnit as e1 ->
|
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TUnit))
|
|
||||||
| A.ELit LEmptyError as e1 ->
|
|
||||||
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TAny (Any.fresh ())))
|
|
||||||
| A.ETuple (es, None) ->
|
| A.ETuple (es, None) ->
|
||||||
let+ es' = bmap (typecheck_expr_bottom_up ctx env) es in
|
let+ es' = bmap (typecheck_expr_bottom_up ctx env) es in
|
||||||
unify_and_mark
|
unify_and_mark
|
||||||
@ -518,7 +524,7 @@ and typecheck_expr_top_down
|
|||||||
match List.nth_opt typs' n with
|
match List.nth_opt typs' n with
|
||||||
| Some t1n -> unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
|
| Some t1n -> unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
|
||||||
| None ->
|
| None ->
|
||||||
Errors.raise_spanned_error (A.Expr.pos e1)
|
Errors.raise_spanned_error (Expr.pos e1)
|
||||||
"Expression should have a tuple type with at least %d elements but \
|
"Expression should have a tuple type with at least %d elements but \
|
||||||
only has %d"
|
only has %d"
|
||||||
n (List.length typs)
|
n (List.length typs)
|
||||||
@ -529,7 +535,7 @@ and typecheck_expr_top_down
|
|||||||
match List.nth_opt ts' n with
|
match List.nth_opt ts' n with
|
||||||
| Some ts_n -> ts_n
|
| Some ts_n -> ts_n
|
||||||
| None ->
|
| None ->
|
||||||
Errors.raise_spanned_error (A.Expr.pos e)
|
Errors.raise_spanned_error (Expr.pos e)
|
||||||
"Expression should have a sum type with at least %d cases but only \
|
"Expression should have a sum type with at least %d cases but only \
|
||||||
has %d"
|
has %d"
|
||||||
n (List.length ts)
|
n (List.length ts)
|
||||||
@ -556,19 +562,19 @@ and typecheck_expr_top_down
|
|||||||
unify_and_mark (EMatch (e1', es', e_name)) t_ret
|
unify_and_mark (EMatch (e1', es', e_name)) t_ret
|
||||||
| A.EAbs (binder, t_args) ->
|
| A.EAbs (binder, t_args) ->
|
||||||
if Bindlib.mbinder_arity binder <> List.length t_args then
|
if Bindlib.mbinder_arity binder <> List.length t_args then
|
||||||
Errors.raise_spanned_error (A.Expr.pos e)
|
Errors.raise_spanned_error (Expr.pos e)
|
||||||
"function has %d variables but was supplied %d types"
|
"function has %d variables but was supplied %d types"
|
||||||
(Bindlib.mbinder_arity binder)
|
(Bindlib.mbinder_arity binder)
|
||||||
(List.length t_args)
|
(List.length t_args)
|
||||||
else
|
else
|
||||||
let xs, body = Bindlib.unmbind binder in
|
let xs, body = Bindlib.unmbind binder in
|
||||||
let xs' = Array.map A.Var.translate xs in
|
let xs' = Array.map Var.translate xs in
|
||||||
let xstaus =
|
let xstaus =
|
||||||
List.map2 (fun x t_arg -> x, ast_to_typ t_arg) (Array.to_list xs) t_args
|
List.map2 (fun x t_arg -> x, ast_to_typ t_arg) (Array.to_list xs) t_args
|
||||||
in
|
in
|
||||||
let env =
|
let env =
|
||||||
List.fold_left
|
List.fold_left
|
||||||
(fun env (x, t_arg) -> A.Var.Map.add x t_arg env)
|
(fun env (x, t_arg) -> Var.Map.add x t_arg env)
|
||||||
env xstaus
|
env xstaus
|
||||||
in
|
in
|
||||||
let body' = typecheck_expr_bottom_up ctx env body in
|
let body' = typecheck_expr_bottom_up ctx env body in
|
||||||
@ -643,21 +649,21 @@ let wrap ctx f e =
|
|||||||
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
|
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
|
||||||
|
|
||||||
(* Infer the type of an expression *)
|
(* Infer the type of an expression *)
|
||||||
let infer_types (ctx : A.decl_ctx) (e : 'm Ast.expr) :
|
let infer_types (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) :
|
||||||
A.typed Ast.expr Bindlib.box =
|
(a, A.typed A.mark) A.gexpr A.box =
|
||||||
A.Expr.map_marks ~f:get_ty_mark
|
Expr.map_marks ~f:get_ty_mark
|
||||||
@@ wrap ctx (typecheck_expr_bottom_up ctx A.Var.Map.empty) e
|
@@ wrap ctx (typecheck_expr_bottom_up ctx Var.Map.empty) e
|
||||||
|
|
||||||
let infer_type (type m) ctx (e : m Ast.expr) =
|
let infer_type (type a m) ctx (e : (a, m A.mark) A.gexpr) =
|
||||||
match Marked.get_mark e with
|
match Marked.get_mark e with
|
||||||
| A.Typed { ty; _ } -> ty
|
| A.Typed { ty; _ } -> ty
|
||||||
| A.Untyped _ -> A.Expr.ty (Bindlib.unbox (infer_types ctx e))
|
| A.Untyped _ -> Expr.ty (Bindlib.unbox (infer_types ctx e))
|
||||||
|
|
||||||
(** Typechecks an expression given an expected type *)
|
(** Typechecks an expression given an expected type *)
|
||||||
let check_type (ctx : A.decl_ctx) (e : 'm Ast.expr) (tau : A.typ) =
|
let check_type (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) (tau : A.typ) =
|
||||||
(* todo: consider using the already inferred type if ['m] = [typed] *)
|
(* todo: consider using the already inferred type if ['m] = [typed] *)
|
||||||
ignore
|
ignore
|
||||||
@@ wrap ctx (typecheck_expr_top_down ctx A.Var.Map.empty (ast_to_typ tau)) e
|
@@ wrap ctx (typecheck_expr_top_down ctx Var.Map.empty (ast_to_typ tau)) e
|
||||||
|
|
||||||
let infer_types_program prg =
|
let infer_types_program prg =
|
||||||
let ctx = prg.A.decl_ctx in
|
let ctx = prg.A.decl_ctx in
|
||||||
@ -686,7 +692,7 @@ let infer_types_program prg =
|
|||||||
let rec process_scope_body_expr env = function
|
let rec process_scope_body_expr env = function
|
||||||
| A.Result e ->
|
| A.Result e ->
|
||||||
let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in
|
let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in
|
||||||
let e' = A.Expr.map_marks ~f:get_ty_mark e' in
|
let e' = Expr.map_marks ~f:get_ty_mark e' in
|
||||||
Bindlib.box_apply (fun e -> A.Result e) e'
|
Bindlib.box_apply (fun e -> A.Result e) e'
|
||||||
| A.ScopeLet
|
| A.ScopeLet
|
||||||
{
|
{
|
||||||
@ -704,9 +710,9 @@ let infer_types_program prg =
|
|||||||
[unify] parameters, which keeps location of the type as defined
|
[unify] parameters, which keeps location of the type as defined
|
||||||
instead of as inferred. *)
|
instead of as inferred. *)
|
||||||
let var, next = Bindlib.unbind scope_let_next in
|
let var, next = Bindlib.unbind scope_let_next in
|
||||||
let env = A.Var.Map.add var ty_e env in
|
let env = Var.Map.add var ty_e env in
|
||||||
let next = process_scope_body_expr env next in
|
let next = process_scope_body_expr env next in
|
||||||
let scope_let_next = Bindlib.bind_var (A.Var.translate var) next in
|
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
|
||||||
Bindlib.box_apply2
|
Bindlib.box_apply2
|
||||||
(fun scope_let_expr scope_let_next ->
|
(fun scope_let_expr scope_let_next ->
|
||||||
A.ScopeLet
|
A.ScopeLet
|
||||||
@ -717,20 +723,20 @@ let infer_types_program prg =
|
|||||||
scope_let_next;
|
scope_let_next;
|
||||||
scope_let_pos;
|
scope_let_pos;
|
||||||
})
|
})
|
||||||
(A.Expr.map_marks ~f:get_ty_mark e)
|
(Expr.map_marks ~f:get_ty_mark e)
|
||||||
scope_let_next
|
scope_let_next
|
||||||
in
|
in
|
||||||
let scope_body_expr =
|
let scope_body_expr =
|
||||||
let var, e = Bindlib.unbind body in
|
let var, e = Bindlib.unbind body in
|
||||||
let env = A.Var.Map.add var ty_in env in
|
let env = Var.Map.add var ty_in env in
|
||||||
let e' = process_scope_body_expr env e in
|
let e' = process_scope_body_expr env e in
|
||||||
Bindlib.bind_var (A.Var.translate var) e'
|
Bindlib.bind_var (Var.translate var) e'
|
||||||
in
|
in
|
||||||
let scope_next =
|
let scope_next =
|
||||||
let scope_var, next = Bindlib.unbind scope_next in
|
let scope_var, next = Bindlib.unbind scope_next in
|
||||||
let env = A.Var.Map.add scope_var ty_scope env in
|
let env = Var.Map.add scope_var ty_scope env in
|
||||||
let next' = process_scopes env next in
|
let next' = process_scopes env next in
|
||||||
Bindlib.bind_var (A.Var.translate scope_var) next'
|
Bindlib.bind_var (Var.translate scope_var) next'
|
||||||
in
|
in
|
||||||
Bindlib.box_apply2
|
Bindlib.box_apply2
|
||||||
(fun scope_body_expr scope_next ->
|
(fun scope_body_expr scope_next ->
|
||||||
@ -747,5 +753,5 @@ let infer_types_program prg =
|
|||||||
})
|
})
|
||||||
scope_body_expr scope_next
|
scope_body_expr scope_next
|
||||||
in
|
in
|
||||||
let scopes = Bindlib.unbox (process_scopes A.Var.Map.empty prg.scopes) in
|
let scopes = Bindlib.unbox (process_scopes Var.Map.empty prg.scopes) in
|
||||||
{ A.decl_ctx = ctx; scopes }
|
{ A.decl_ctx = ctx; scopes }
|
@ -17,15 +17,18 @@
|
|||||||
(** Typing for the default calculus. Because of the error terms, we perform type
|
(** Typing for the default calculus. Because of the error terms, we perform type
|
||||||
inference using the classical W algorithm with union-find unification. *)
|
inference using the classical W algorithm with union-find unification. *)
|
||||||
|
|
||||||
open Shared_ast
|
open Definitions
|
||||||
|
|
||||||
val infer_types : decl_ctx -> untyped Ast.expr -> typed Ast.expr Bindlib.box
|
val infer_types :
|
||||||
|
decl_ctx -> ('a, untyped mark) gexpr -> ('a, typed mark) gexpr box
|
||||||
(** Infers types everywhere on the given expression, and adds (or replaces) type
|
(** Infers types everywhere on the given expression, and adds (or replaces) type
|
||||||
annotations on each node *)
|
annotations on each node *)
|
||||||
|
|
||||||
val infer_type : decl_ctx -> 'm Ast.expr -> typ
|
val infer_type : decl_ctx -> ('a, 'm mark) gexpr -> typ
|
||||||
(** Gets the outer type of the given expression, using either the existing
|
(** Gets the outer type of the given expression, using either the existing
|
||||||
annotations or inference *)
|
annotations or inference *)
|
||||||
|
|
||||||
val check_type : decl_ctx -> 'm Ast.expr -> typ -> unit
|
val check_type : decl_ctx -> ('a, 'm mark) gexpr -> typ -> unit
|
||||||
val infer_types_program : untyped Ast.program -> typed Ast.program
|
|
||||||
|
val infer_types_program :
|
||||||
|
('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program
|
Loading…
Reference in New Issue
Block a user