Generalise the typer

This moves dcalc/typing.ml to shared_ast, and generalises the input type, but
without yet implementing the extra cases (these are all `assert false`): it's
just a first step.
This commit is contained in:
Louis Gesbert 2022-09-13 15:20:13 +02:00
parent 0bb9cce341
commit b37a6c3703
5 changed files with 101 additions and 91 deletions

View File

@ -219,7 +219,7 @@ let driver source_file (options : Cli.options) : int =
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc
| `Proof | `Plugin _ ) as backend -> (
Cli.debug_print "Typechecking...";
let prgm = Dcalc.Typing.infer_types_program prgm in
let prgm = Shared_ast.Typing.infer_types_program prgm in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
(Print.typ prgm.decl_ctx) typ); *)
match backend with

View File

@ -264,7 +264,7 @@ type typed = { pos : Pos.t; ty : typ }
type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark
(** Useful for errors and printing, for example *)
type any_expr = AnyExpr : (_ any, _ mark) gexpr -> any_expr
type any_expr = AnyExpr : (_, _ mark) gexpr -> any_expr
(** {2 Higher-level program structure} *)

View File

@ -20,3 +20,4 @@ module Expr = Expr
module Scope = Scope
module Program = Program
module Print = Print
module Typing = Typing

View File

@ -18,7 +18,7 @@
inference using the classical W algorithm with union-find unification. *)
open Utils
module A = Shared_ast
module A = Definitions
module Any =
Utils.Uid.Make
@ -33,8 +33,8 @@ module Any =
()
type unionfind_typ = naked_typ Marked.pos UnionFind.elem
(** We do not reuse {!type: Dcalc.Ast.naked_typ} because we have to include a
new [TAny] variant. Indeed, error terms can have any type and this has to be
(** We do not reuse {!type: Shared_ast.typ} because we have to include a new
[TAny] variant. Indeed, error terms can have any type and this has to be
captured by the type system. *)
and naked_typ =
@ -90,7 +90,7 @@ let rec format_typ
in
let naked_typ = UnionFind.get (UnionFind.find naked_typ) in
match Marked.unmark naked_typ with
| TLit l -> Format.fprintf fmt "%a" A.Print.tlit l
| TLit l -> Format.fprintf fmt "%a" Print.tlit l
| TTuple ts ->
Format.fprintf fmt "@[<hov 2>(%a)]"
(Format.pp_print_list
@ -195,6 +195,17 @@ let handle_type_error ctx e t1 t2 =
(Cli.format_with_style [ANSITerminal.blue; ANSITerminal.Bold])
"-->" t2_s ()
(** [lit_type lit] is the naked type assigned to the literal [lit]: every
    literal maps to the [TLit] of its base type, except [LEmptyError], which
    gets a [TAny] built from [Any.fresh ()]. *)
let lit_type : type a. a A.glit -> naked_typ = function
  | LUnit -> TLit TUnit
  | LBool _ -> TLit TBool
  | LInt _ -> TLit TInt
  | LRat _ -> TLit TRat
  | LMoney _ -> TLit TMoney
  | LDate _ -> TLit TDate
  | LDuration _ -> TLit TDuration
  | LEmptyError -> TAny (Any.fresh ())
(** Operators have a single type, instead of being polymorphic with constraints.
This allows us to have a simpler type system, while we argue the syntactic
burden of operator annotations helps the programmer visualize the type flow
@ -265,9 +276,9 @@ let op_type (op : A.operator Marked.pos) : unionfind_typ =
(** {1 Double-directed typing} *)
type 'e env = ('e, unionfind_typ) A.Var.Map.t
type 'e env = ('e, unionfind_typ) Var.Map.t
let add_pos e ty = Marked.mark (A.Expr.pos e) ty
let add_pos e ty = Marked.mark (Expr.pos e) ty
(** [ty e] projects the [uf] field out of the mark paired with [e]. *)
let ty (_, { uf; _ }) = uf
(** Binding operator: [let+ v = x in body] stands for
    [Bindlib.box_apply (fun v -> body) x]. *)
let ( let+ ) x f = Bindlib.box_apply f x
(** Companion of [let+]: [and+] combines its two operands with
    [Bindlib.box_pair]. *)
let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2
@ -293,41 +304,40 @@ let bmap2 (f : 'a -> 'b -> 'c Bindlib.box) (es : 'a list) (xs : 'b list) :
(** [box_ty e] applies [ty] under the box [e] with [Bindlib.box_apply], then
    unboxes the result with [Bindlib.unbox]. *)
let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
(** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up
(ctx : A.decl_ctx)
(env : 'm Ast.expr env)
(e : 'm Ast.expr) : (A.dcalc, mark) A.gexpr Bindlib.box =
let rec typecheck_expr_bottom_up :
type a.
A.decl_ctx ->
(a, 'm A.mark) A.gexpr env ->
(a, 'm A.mark) A.gexpr ->
(a, mark) A.gexpr A.box =
fun ctx env e ->
(* Cli.debug_format "Looking for type of %a" (Expr.format ~debug:true ctx)
e; *)
let pos_e = A.Expr.pos e in
let mark (e : (A.dcalc, mark) A.naked_gexpr) uf =
Marked.mark { uf; pos = pos_e } e
in
let pos_e = Expr.pos e in
let mark e uf = Marked.mark { uf; pos = pos_e } e in
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in
match Marked.unmark e with
| A.ELocation _ -> assert false
| A.EStruct _ -> assert false
| A.EStructAccess _ -> assert false
| A.EEnumInj _ -> assert false
| A.EMatchS _ -> assert false
| A.ERaise _ -> assert false
| A.ECatch _ -> assert false
| A.EVar v -> begin
match A.Var.Map.find_opt v env with
match Var.Map.find_opt v env with
| Some t ->
let+ v' = Bindlib.box_var (A.Var.translate v) in
let+ v' = Bindlib.box_var (Var.translate v) in
mark v' t
| None ->
Errors.raise_spanned_error (A.Expr.pos e)
Errors.raise_spanned_error (Expr.pos e)
"Variable %s not found in the current context." (Bindlib.name_of v)
end
| A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool)
| A.ELit (LInt _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TInt)
| A.ELit (LRat _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TRat)
| A.ELit (LMoney _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TMoney)
| A.ELit (LDate _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TDate)
| A.ELit (LDuration _) as e1 ->
Bindlib.box @@ mark_with_uf e1 (TLit TDuration)
| A.ELit LUnit as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TUnit)
| A.ELit LEmptyError as e1 ->
Bindlib.box @@ mark_with_uf e1 (TAny (Any.fresh ()))
| A.ELit lit as e1 -> Bindlib.box @@ mark_with_uf e1 (lit_type lit)
| A.ETuple (es, None) ->
let+ es = bmap (typecheck_expr_bottom_up ctx env) es in
mark_with_uf (ETuple (es, None)) (TTuple (List.map ty es))
mark_with_uf (A.ETuple (es, None)) (TTuple (List.map ty es))
| A.ETuple (es, Some s_name) ->
let tys =
List.map
@ -335,13 +345,13 @@ let rec typecheck_expr_bottom_up
(A.StructMap.find s_name ctx.A.ctx_structs)
in
let+ es = bmap2 (typecheck_expr_top_down ctx env) tys es in
mark_with_uf (ETuple (es, Some s_name)) (TStruct s_name)
mark_with_uf (A.ETuple (es, Some s_name)) (TStruct s_name)
| A.ETupleAccess (e1, n, s, typs) -> begin
let utyps = List.map ast_to_typ typs in
let tuple_ty = match s with None -> TTuple utyps | Some s -> TStruct s in
let+ e1 = typecheck_expr_top_down ctx env (unionfind_make tuple_ty) e1 in
match List.nth_opt utyps n with
| Some t' -> mark (ETupleAccess (e1, n, s, typs)) t'
| Some t' -> mark (A.ETupleAccess (e1, n, s, typs)) t'
| None ->
Errors.raise_spanned_error (Marked.get_mark e1).pos
"Expression should have a tuple type with at least %d elements but \
@ -354,7 +364,7 @@ let rec typecheck_expr_bottom_up
match List.nth_opt ts' n with
| Some ts_n -> ts_n
| None ->
Errors.raise_spanned_error (A.Expr.pos e)
Errors.raise_spanned_error (Expr.pos e)
"Expression should have a sum type with at least %d cases but only \
has %d"
n (List.length ts')
@ -376,19 +386,19 @@ let rec typecheck_expr_bottom_up
es')
es enum_cases
in
mark (EMatch (e1', es', e_name)) t_ret
mark (A.EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, taus) ->
if Bindlib.mbinder_arity binder <> List.length taus then
Errors.raise_spanned_error (A.Expr.pos e)
Errors.raise_spanned_error (Expr.pos e)
"function has %d variables but was supplied %d types"
(Bindlib.mbinder_arity binder)
(List.length taus)
else
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map A.Var.translate xs in
let xs' = Array.map Var.translate xs in
let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in
let env =
List.fold_left (fun env (x, tau) -> A.Var.Map.add x tau env) env xstaus
List.fold_left (fun env (x, tau) -> Var.Map.add x tau env) env xstaus
in
let body' = typecheck_expr_bottom_up ctx env body in
let t_func =
@ -397,7 +407,7 @@ let rec typecheck_expr_bottom_up
xstaus (box_ty body')
in
let+ binder' = Bindlib.bind_mvar xs' body' in
mark (EAbs (binder', taus)) t_func
mark (A.EAbs (binder', taus)) t_func
| A.EApp (e1, args) ->
let args' = bmap (typecheck_expr_bottom_up ctx env) args in
let t_ret = unionfind_make (TAny (Any.fresh ())) in
@ -409,7 +419,7 @@ let rec typecheck_expr_bottom_up
in
let+ e1' = typecheck_expr_bottom_up ctx env e1 and+ args' in
unify ctx e (ty e1') t_func;
mark (EApp (e1', args')) t_ret
mark (A.EApp (e1', args')) t_ret
| A.EOp op as e1 -> Bindlib.box @@ mark e1 (op_type (Marked.mark pos_e op))
| A.EDefault (excepts, just, cons) ->
let just' =
@ -456,46 +466,42 @@ let rec typecheck_expr_bottom_up
mark_with_uf (A.EArray es') (TArray cell_type)
(** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down
(ctx : A.decl_ctx)
(env : 'm Ast.expr env)
(tau : unionfind_typ)
(e : 'm Ast.expr) : (A.dcalc, mark) A.gexpr Bindlib.box =
and typecheck_expr_top_down :
type a.
A.decl_ctx ->
(a, 'm A.mark) A.gexpr env ->
unionfind_typ ->
(a, 'm A.mark) A.gexpr ->
(a, mark) A.gexpr Bindlib.box =
fun ctx env tau e ->
(* Cli.debug_format "Propagating type %a for naked_expr %a" (format_typ ctx)
tau (Expr.format ctx) e; *)
let pos_e = A.Expr.pos e in
let pos_e = Expr.pos e in
let mark e = Marked.mark { uf = tau; pos = pos_e } e in
let unify_and_mark (e' : (A.dcalc, mark) A.naked_gexpr) tau' =
let unify_and_mark (e' : (a, mark) A.naked_gexpr) tau' =
unify ctx e tau' tau;
Marked.mark { uf = tau; pos = pos_e } e'
in
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
match Marked.unmark e with
| A.ELocation _ -> assert false
| A.EStruct _ -> assert false
| A.EStructAccess _ -> assert false
| A.EEnumInj _ -> assert false
| A.EMatchS _ -> assert false
| A.ERaise _ -> assert false
| A.ECatch _ -> assert false
| A.EVar v -> begin
match A.Var.Map.find_opt v env with
match Var.Map.find_opt v env with
| Some tau' ->
let+ v' = Bindlib.box_var (A.Var.translate v) in
let+ v' = Bindlib.box_var (Var.translate v) in
unify_and_mark v' tau'
| None ->
Errors.raise_spanned_error pos_e
"Variable %s not found in the current context" (Bindlib.name_of v)
end
| A.ELit (LBool _) as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TBool))
| A.ELit (LInt _) as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TInt))
| A.ELit (LRat _) as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TRat))
| A.ELit (LMoney _) as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TMoney))
| A.ELit (LDate _) as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDate))
| A.ELit (LDuration _) as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDuration))
| A.ELit LUnit as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TUnit))
| A.ELit LEmptyError as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (TAny (Any.fresh ())))
| A.ELit lit as e1 ->
Bindlib.box @@ unify_and_mark e1 (unionfind_make (lit_type lit))
| A.ETuple (es, None) ->
let+ es' = bmap (typecheck_expr_bottom_up ctx env) es in
unify_and_mark
@ -518,7 +524,7 @@ and typecheck_expr_top_down
match List.nth_opt typs' n with
| Some t1n -> unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n
| None ->
Errors.raise_spanned_error (A.Expr.pos e1)
Errors.raise_spanned_error (Expr.pos e1)
"Expression should have a tuple type with at least %d elements but \
only has %d"
n (List.length typs)
@ -529,7 +535,7 @@ and typecheck_expr_top_down
match List.nth_opt ts' n with
| Some ts_n -> ts_n
| None ->
Errors.raise_spanned_error (A.Expr.pos e)
Errors.raise_spanned_error (Expr.pos e)
"Expression should have a sum type with at least %d cases but only \
has %d"
n (List.length ts)
@ -556,19 +562,19 @@ and typecheck_expr_top_down
unify_and_mark (EMatch (e1', es', e_name)) t_ret
| A.EAbs (binder, t_args) ->
if Bindlib.mbinder_arity binder <> List.length t_args then
Errors.raise_spanned_error (A.Expr.pos e)
Errors.raise_spanned_error (Expr.pos e)
"function has %d variables but was supplied %d types"
(Bindlib.mbinder_arity binder)
(List.length t_args)
else
let xs, body = Bindlib.unmbind binder in
let xs' = Array.map A.Var.translate xs in
let xs' = Array.map Var.translate xs in
let xstaus =
List.map2 (fun x t_arg -> x, ast_to_typ t_arg) (Array.to_list xs) t_args
in
let env =
List.fold_left
(fun env (x, t_arg) -> A.Var.Map.add x t_arg env)
(fun env (x, t_arg) -> Var.Map.add x t_arg env)
env xstaus
in
let body' = typecheck_expr_bottom_up ctx env body in
@ -643,21 +649,21 @@ let wrap ctx f e =
(** Turns an inference mark [{ uf; pos }] into an [A.Typed] mark: [pos] is kept
    as is and [uf] is converted with [typ_to_ast] to fill the [ty] field. *)
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
(* Infer the type of an expression *)
let infer_types (ctx : A.decl_ctx) (e : 'm Ast.expr) :
A.typed Ast.expr Bindlib.box =
A.Expr.map_marks ~f:get_ty_mark
@@ wrap ctx (typecheck_expr_bottom_up ctx A.Var.Map.empty) e
let infer_types (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) :
(a, A.typed A.mark) A.gexpr A.box =
Expr.map_marks ~f:get_ty_mark
@@ wrap ctx (typecheck_expr_bottom_up ctx Var.Map.empty) e
let infer_type (type m) ctx (e : m Ast.expr) =
let infer_type (type a m) ctx (e : (a, m A.mark) A.gexpr) =
match Marked.get_mark e with
| A.Typed { ty; _ } -> ty
| A.Untyped _ -> A.Expr.ty (Bindlib.unbox (infer_types ctx e))
| A.Untyped _ -> Expr.ty (Bindlib.unbox (infer_types ctx e))
(** Typechecks an expression given an expected type *)
let check_type (ctx : A.decl_ctx) (e : 'm Ast.expr) (tau : A.typ) =
let check_type (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) (tau : A.typ) =
(* todo: consider using the already inferred type if ['m] = [typed] *)
ignore
@@ wrap ctx (typecheck_expr_top_down ctx A.Var.Map.empty (ast_to_typ tau)) e
@@ wrap ctx (typecheck_expr_top_down ctx Var.Map.empty (ast_to_typ tau)) e
let infer_types_program prg =
let ctx = prg.A.decl_ctx in
@ -686,7 +692,7 @@ let infer_types_program prg =
let rec process_scope_body_expr env = function
| A.Result e ->
let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in
let e' = A.Expr.map_marks ~f:get_ty_mark e' in
let e' = Expr.map_marks ~f:get_ty_mark e' in
Bindlib.box_apply (fun e -> A.Result e) e'
| A.ScopeLet
{
@ -704,9 +710,9 @@ let infer_types_program prg =
[unify] parameters, which keeps location of the type as defined
instead of as inferred. *)
let var, next = Bindlib.unbind scope_let_next in
let env = A.Var.Map.add var ty_e env in
let env = Var.Map.add var ty_e env in
let next = process_scope_body_expr env next in
let scope_let_next = Bindlib.bind_var (A.Var.translate var) next in
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
Bindlib.box_apply2
(fun scope_let_expr scope_let_next ->
A.ScopeLet
@ -717,20 +723,20 @@ let infer_types_program prg =
scope_let_next;
scope_let_pos;
})
(A.Expr.map_marks ~f:get_ty_mark e)
(Expr.map_marks ~f:get_ty_mark e)
scope_let_next
in
let scope_body_expr =
let var, e = Bindlib.unbind body in
let env = A.Var.Map.add var ty_in env in
let env = Var.Map.add var ty_in env in
let e' = process_scope_body_expr env e in
Bindlib.bind_var (A.Var.translate var) e'
Bindlib.bind_var (Var.translate var) e'
in
let scope_next =
let scope_var, next = Bindlib.unbind scope_next in
let env = A.Var.Map.add scope_var ty_scope env in
let env = Var.Map.add scope_var ty_scope env in
let next' = process_scopes env next in
Bindlib.bind_var (A.Var.translate scope_var) next'
Bindlib.bind_var (Var.translate scope_var) next'
in
Bindlib.box_apply2
(fun scope_body_expr scope_next ->
@ -747,5 +753,5 @@ let infer_types_program prg =
})
scope_body_expr scope_next
in
let scopes = Bindlib.unbox (process_scopes A.Var.Map.empty prg.scopes) in
let scopes = Bindlib.unbox (process_scopes Var.Map.empty prg.scopes) in
{ A.decl_ctx = ctx; scopes }

View File

@ -17,15 +17,18 @@
(** Typing for the default calculus. Because of the error terms, we perform type
inference using the classical W algorithm with union-find unification. *)
open Shared_ast
open Definitions
val infer_types : decl_ctx -> untyped Ast.expr -> typed Ast.expr Bindlib.box
val infer_types :
decl_ctx -> ('a, untyped mark) gexpr -> ('a, typed mark) gexpr box
(** Infers types everywhere on the given expression, and adds (or replaces) type
annotations on each node *)
val infer_type : decl_ctx -> 'm Ast.expr -> typ
val infer_type : decl_ctx -> ('a, 'm mark) gexpr -> typ
(** Gets the outer type of the given expression, using either the existing
annotations or inference *)
val check_type : decl_ctx -> 'm Ast.expr -> typ -> unit
val infer_types_program : untyped Ast.program -> typed Ast.program
val check_type : decl_ctx -> ('a, 'm mark) gexpr -> typ -> unit
val infer_types_program :
('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program