diff --git a/compiler/driver.ml b/compiler/driver.ml index a8edcc36..ee41a9e6 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -219,7 +219,7 @@ let driver source_file (options : Cli.options) : int = | ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc | `Proof | `Plugin _ ) as backend -> ( Cli.debug_print "Typechecking..."; - let prgm = Dcalc.Typing.infer_types_program prgm in + let prgm = Shared_ast.Typing.infer_types_program prgm in (* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a" (Print.typ prgm.decl_ctx) typ); *) match backend with diff --git a/compiler/shared_ast/definitions.ml b/compiler/shared_ast/definitions.ml index 8d1e3d4e..4f31dbff 100644 --- a/compiler/shared_ast/definitions.ml +++ b/compiler/shared_ast/definitions.ml @@ -264,7 +264,7 @@ type typed = { pos : Pos.t; ty : typ } type _ mark = Untyped : untyped -> untyped mark | Typed : typed -> typed mark (** Useful for errors and printing, for example *) -type any_expr = AnyExpr : (_ any, _ mark) gexpr -> any_expr +type any_expr = AnyExpr : (_, _ mark) gexpr -> any_expr (** {2 Higher-level program structure} *) diff --git a/compiler/shared_ast/shared_ast.ml b/compiler/shared_ast/shared_ast.ml index 5d858f08..1a170742 100644 --- a/compiler/shared_ast/shared_ast.ml +++ b/compiler/shared_ast/shared_ast.ml @@ -20,3 +20,4 @@ module Expr = Expr module Scope = Scope module Program = Program module Print = Print +module Typing = Typing diff --git a/compiler/dcalc/typing.ml b/compiler/shared_ast/typing.ml similarity index 85% rename from compiler/dcalc/typing.ml rename to compiler/shared_ast/typing.ml index 140ae8ad..304b6c84 100644 --- a/compiler/dcalc/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -18,7 +18,7 @@ inference using the classical W algorithm with union-find unification. 
*) open Utils -module A = Shared_ast +module A = Definitions module Any = Utils.Uid.Make @@ -33,8 +33,8 @@ module Any = () type unionfind_typ = naked_typ Marked.pos UnionFind.elem -(** We do not reuse {!type: Dcalc.Ast.naked_typ} because we have to include a - new [TAny] variant. Indeed, error terms can have any type and this has to be +(** We do not reuse {!type: Shared_ast.typ} because we have to include a new + [TAny] variant. Indeed, error terms can have any type and this has to be captured by the type system. *) and naked_typ = @@ -90,7 +90,7 @@ let rec format_typ in let naked_typ = UnionFind.get (UnionFind.find naked_typ) in match Marked.unmark naked_typ with - | TLit l -> Format.fprintf fmt "%a" A.Print.tlit l + | TLit l -> Format.fprintf fmt "%a" Print.tlit l | TTuple ts -> Format.fprintf fmt "@[(%a)]" (Format.pp_print_list @@ -195,6 +195,17 @@ let handle_type_error ctx e t1 t2 = (Cli.format_with_style [ANSITerminal.blue; ANSITerminal.Bold]) "-->" t2_s () (** [lit_type lit] returns the naked type of the literal [lit]. Every concrete literal maps to the matching [TLit] type. [LEmptyError] maps to a [TAny] built from a fresh identifier obtained with [Any.fresh ()], since error terms can have any type. Both typechecking directions share this helper instead of listing one case per literal. *) +let lit_type (type a) (lit : a A.glit) : naked_typ = + match lit with + | LBool _ -> TLit TBool + | LInt _ -> TLit TInt + | LRat _ -> TLit TRat + | LMoney _ -> TLit TMoney + | LDate _ -> TLit TDate + | LDuration _ -> TLit TDuration + | LUnit -> TLit TUnit + | LEmptyError -> TAny (Any.fresh ()) + (** Operators have a single type, instead of being polymorphic with constraints.
This allows us to have a simpler type system, while we argue the syntactic burden of operator annotations helps the programmer visualize the type flow @@ -265,9 +276,9 @@ let op_type (op : A.operator Marked.pos) : unionfind_typ = (** {1 Double-directed typing} *) -type 'e env = ('e, unionfind_typ) A.Var.Map.t +type 'e env = ('e, unionfind_typ) Var.Map.t -let add_pos e ty = Marked.mark (A.Expr.pos e) ty +let add_pos e ty = Marked.mark (Expr.pos e) ty let ty (_, { uf; _ }) = uf let ( let+ ) x f = Bindlib.box_apply f x let ( and+ ) x1 x2 = Bindlib.box_pair x1 x2 @@ -293,41 +304,40 @@ let bmap2 (f : 'a -> 'b -> 'c Bindlib.box) (es : 'a list) (xs : 'b list) : let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e) (** Infers the most permissive type from an expression *) -let rec typecheck_expr_bottom_up - (ctx : A.decl_ctx) - (env : 'm Ast.expr env) - (e : 'm Ast.expr) : (A.dcalc, mark) A.gexpr Bindlib.box = +let rec typecheck_expr_bottom_up : + type a. + A.decl_ctx -> + (a, 'm A.mark) A.gexpr env -> + (a, 'm A.mark) A.gexpr -> + (a, mark) A.gexpr A.box = + fun ctx env e -> (* Cli.debug_format "Looking for type of %a" (Expr.format ~debug:true ctx) e; *) - let pos_e = A.Expr.pos e in - let mark (e : (A.dcalc, mark) A.naked_gexpr) uf = - Marked.mark { uf; pos = pos_e } e - in + let pos_e = Expr.pos e in + let mark e uf = Marked.mark { uf; pos = pos_e } e in let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in match Marked.unmark e with + | A.ELocation _ -> assert false + | A.EStruct _ -> assert false + | A.EStructAccess _ -> assert false + | A.EEnumInj _ -> assert false + | A.EMatchS _ -> assert false + | A.ERaise _ -> assert false + | A.ECatch _ -> assert false | A.EVar v -> begin - match A.Var.Map.find_opt v env with + match Var.Map.find_opt v env with | Some t -> - let+ v' = Bindlib.box_var (A.Var.translate v) in + let+ v' = Bindlib.box_var (Var.translate v) in mark v' t | None -> - 
Errors.raise_spanned_error (A.Expr.pos e) + Errors.raise_spanned_error (Expr.pos e) "Variable %s not found in the current context." (Bindlib.name_of v) end - | A.ELit (LBool _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TBool) - | A.ELit (LInt _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TInt) - | A.ELit (LRat _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TRat) - | A.ELit (LMoney _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TMoney) - | A.ELit (LDate _) as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TDate) - | A.ELit (LDuration _) as e1 -> - Bindlib.box @@ mark_with_uf e1 (TLit TDuration) - | A.ELit LUnit as e1 -> Bindlib.box @@ mark_with_uf e1 (TLit TUnit) - | A.ELit LEmptyError as e1 -> - Bindlib.box @@ mark_with_uf e1 (TAny (Any.fresh ())) + | A.ELit lit as e1 -> Bindlib.box @@ mark_with_uf e1 (lit_type lit) | A.ETuple (es, None) -> let+ es = bmap (typecheck_expr_bottom_up ctx env) es in - mark_with_uf (ETuple (es, None)) (TTuple (List.map ty es)) + mark_with_uf (A.ETuple (es, None)) (TTuple (List.map ty es)) | A.ETuple (es, Some s_name) -> let tys = List.map @@ -335,13 +345,13 @@ let rec typecheck_expr_bottom_up (A.StructMap.find s_name ctx.A.ctx_structs) in let+ es = bmap2 (typecheck_expr_top_down ctx env) tys es in - mark_with_uf (ETuple (es, Some s_name)) (TStruct s_name) + mark_with_uf (A.ETuple (es, Some s_name)) (TStruct s_name) | A.ETupleAccess (e1, n, s, typs) -> begin let utyps = List.map ast_to_typ typs in let tuple_ty = match s with None -> TTuple utyps | Some s -> TStruct s in let+ e1 = typecheck_expr_top_down ctx env (unionfind_make tuple_ty) e1 in match List.nth_opt utyps n with - | Some t' -> mark (ETupleAccess (e1, n, s, typs)) t' + | Some t' -> mark (A.ETupleAccess (e1, n, s, typs)) t' | None -> Errors.raise_spanned_error (Marked.get_mark e1).pos "Expression should have a tuple type with at least %d elements but \ @@ -354,7 +364,7 @@ let rec typecheck_expr_bottom_up match List.nth_opt ts' n with | Some ts_n -> ts_n | None -> - 
Errors.raise_spanned_error (A.Expr.pos e) + Errors.raise_spanned_error (Expr.pos e) "Expression should have a sum type with at least %d cases but only \ has %d" n (List.length ts') @@ -376,19 +386,19 @@ let rec typecheck_expr_bottom_up es') es enum_cases in - mark (EMatch (e1', es', e_name)) t_ret + mark (A.EMatch (e1', es', e_name)) t_ret | A.EAbs (binder, taus) -> if Bindlib.mbinder_arity binder <> List.length taus then - Errors.raise_spanned_error (A.Expr.pos e) + Errors.raise_spanned_error (Expr.pos e) "function has %d variables but was supplied %d types" (Bindlib.mbinder_arity binder) (List.length taus) else let xs, body = Bindlib.unmbind binder in - let xs' = Array.map A.Var.translate xs in + let xs' = Array.map Var.translate xs in let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in let env = - List.fold_left (fun env (x, tau) -> A.Var.Map.add x tau env) env xstaus + List.fold_left (fun env (x, tau) -> Var.Map.add x tau env) env xstaus in let body' = typecheck_expr_bottom_up ctx env body in let t_func = @@ -397,7 +407,7 @@ let rec typecheck_expr_bottom_up xstaus (box_ty body') in let+ binder' = Bindlib.bind_mvar xs' body' in - mark (EAbs (binder', taus)) t_func + mark (A.EAbs (binder', taus)) t_func | A.EApp (e1, args) -> let args' = bmap (typecheck_expr_bottom_up ctx env) args in let t_ret = unionfind_make (TAny (Any.fresh ())) in @@ -409,7 +419,7 @@ let rec typecheck_expr_bottom_up in let+ e1' = typecheck_expr_bottom_up ctx env e1 and+ args' in unify ctx e (ty e1') t_func; - mark (EApp (e1', args')) t_ret + mark (A.EApp (e1', args')) t_ret | A.EOp op as e1 -> Bindlib.box @@ mark e1 (op_type (Marked.mark pos_e op)) | A.EDefault (excepts, just, cons) -> let just' = @@ -456,46 +466,42 @@ let rec typecheck_expr_bottom_up mark_with_uf (A.EArray es') (TArray cell_type) (** Checks whether the expression can be typed with the provided type *) -and typecheck_expr_top_down - (ctx : A.decl_ctx) - (env : 'm Ast.expr env) - (tau : unionfind_typ) - (e : 
'm Ast.expr) : (A.dcalc, mark) A.gexpr Bindlib.box = +and typecheck_expr_top_down : + type a. + A.decl_ctx -> + (a, 'm A.mark) A.gexpr env -> + unionfind_typ -> + (a, 'm A.mark) A.gexpr -> + (a, mark) A.gexpr Bindlib.box = + fun ctx env tau e -> (* Cli.debug_format "Propagating type %a for naked_expr %a" (format_typ ctx) tau (Expr.format ctx) e; *) - let pos_e = A.Expr.pos e in + let pos_e = Expr.pos e in let mark e = Marked.mark { uf = tau; pos = pos_e } e in - let unify_and_mark (e' : (A.dcalc, mark) A.naked_gexpr) tau' = + let unify_and_mark (e' : (a, mark) A.naked_gexpr) tau' = unify ctx e tau' tau; Marked.mark { uf = tau; pos = pos_e } e' in let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in match Marked.unmark e with + | A.ELocation _ -> assert false + | A.EStruct _ -> assert false + | A.EStructAccess _ -> assert false + | A.EEnumInj _ -> assert false + | A.EMatchS _ -> assert false + | A.ERaise _ -> assert false + | A.ECatch _ -> assert false | A.EVar v -> begin - match A.Var.Map.find_opt v env with + match Var.Map.find_opt v env with | Some tau' -> - let+ v' = Bindlib.box_var (A.Var.translate v) in + let+ v' = Bindlib.box_var (Var.translate v) in unify_and_mark v' tau' | None -> Errors.raise_spanned_error pos_e "Variable %s not found in the current context" (Bindlib.name_of v) end - | A.ELit (LBool _) as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TBool)) - | A.ELit (LInt _) as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TInt)) - | A.ELit (LRat _) as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TRat)) - | A.ELit (LMoney _) as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TMoney)) - | A.ELit (LDate _) as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDate)) - | A.ELit (LDuration _) as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TDuration)) - | A.ELit LUnit as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TLit TUnit)) - | 
A.ELit LEmptyError as e1 -> - Bindlib.box @@ unify_and_mark e1 (unionfind_make (TAny (Any.fresh ()))) + | A.ELit lit as e1 -> + Bindlib.box @@ unify_and_mark e1 (unionfind_make (lit_type lit)) | A.ETuple (es, None) -> let+ es' = bmap (typecheck_expr_bottom_up ctx env) es in unify_and_mark @@ -518,7 +524,7 @@ and typecheck_expr_top_down match List.nth_opt typs' n with | Some t1n -> unify_and_mark (A.ETupleAccess (e1', n, s, typs)) t1n | None -> - Errors.raise_spanned_error (A.Expr.pos e1) + Errors.raise_spanned_error (Expr.pos e1) "Expression should have a tuple type with at least %d elements but \ only has %d" n (List.length typs) @@ -529,7 +535,7 @@ and typecheck_expr_top_down match List.nth_opt ts' n with | Some ts_n -> ts_n | None -> - Errors.raise_spanned_error (A.Expr.pos e) + Errors.raise_spanned_error (Expr.pos e) "Expression should have a sum type with at least %d cases but only \ has %d" n (List.length ts) @@ -556,19 +562,19 @@ and typecheck_expr_top_down unify_and_mark (EMatch (e1', es', e_name)) t_ret | A.EAbs (binder, t_args) -> if Bindlib.mbinder_arity binder <> List.length t_args then - Errors.raise_spanned_error (A.Expr.pos e) + Errors.raise_spanned_error (Expr.pos e) "function has %d variables but was supplied %d types" (Bindlib.mbinder_arity binder) (List.length t_args) else let xs, body = Bindlib.unmbind binder in - let xs' = Array.map A.Var.translate xs in + let xs' = Array.map Var.translate xs in let xstaus = List.map2 (fun x t_arg -> x, ast_to_typ t_arg) (Array.to_list xs) t_args in let env = List.fold_left - (fun env (x, t_arg) -> A.Var.Map.add x t_arg env) + (fun env (x, t_arg) -> Var.Map.add x t_arg env) env xstaus in let body' = typecheck_expr_bottom_up ctx env body in @@ -643,21 +649,21 @@ let wrap ctx f e = let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos } (* Infer the type of an expression *) -let infer_types (ctx : A.decl_ctx) (e : 'm Ast.expr) : - A.typed Ast.expr Bindlib.box = - A.Expr.map_marks ~f:get_ty_mark - @@ 
wrap ctx (typecheck_expr_bottom_up ctx A.Var.Map.empty) e +let infer_types (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) : + (a, A.typed A.mark) A.gexpr A.box = + Expr.map_marks ~f:get_ty_mark + @@ wrap ctx (typecheck_expr_bottom_up ctx Var.Map.empty) e -let infer_type (type m) ctx (e : m Ast.expr) = +let infer_type (type a m) ctx (e : (a, m A.mark) A.gexpr) = match Marked.get_mark e with | A.Typed { ty; _ } -> ty - | A.Untyped _ -> A.Expr.ty (Bindlib.unbox (infer_types ctx e)) + | A.Untyped _ -> Expr.ty (Bindlib.unbox (infer_types ctx e)) (** Typechecks an expression given an expected type *) -let check_type (ctx : A.decl_ctx) (e : 'm Ast.expr) (tau : A.typ) = +let check_type (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) (tau : A.typ) = (* todo: consider using the already inferred type if ['m] = [typed] *) ignore - @@ wrap ctx (typecheck_expr_top_down ctx A.Var.Map.empty (ast_to_typ tau)) e + @@ wrap ctx (typecheck_expr_top_down ctx Var.Map.empty (ast_to_typ tau)) e let infer_types_program prg = let ctx = prg.A.decl_ctx in @@ -686,7 +692,7 @@ let infer_types_program prg = let rec process_scope_body_expr env = function | A.Result e -> let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in - let e' = A.Expr.map_marks ~f:get_ty_mark e' in + let e' = Expr.map_marks ~f:get_ty_mark e' in Bindlib.box_apply (fun e -> A.Result e) e' | A.ScopeLet { @@ -704,9 +710,9 @@ let infer_types_program prg = [unify] parameters, which keeps location of the type as defined instead of as inferred. 
*) let var, next = Bindlib.unbind scope_let_next in - let env = A.Var.Map.add var ty_e env in + let env = Var.Map.add var ty_e env in let next = process_scope_body_expr env next in - let scope_let_next = Bindlib.bind_var (A.Var.translate var) next in + let scope_let_next = Bindlib.bind_var (Var.translate var) next in Bindlib.box_apply2 (fun scope_let_expr scope_let_next -> A.ScopeLet @@ -717,20 +723,20 @@ let infer_types_program prg = scope_let_next; scope_let_pos; }) - (A.Expr.map_marks ~f:get_ty_mark e) + (Expr.map_marks ~f:get_ty_mark e) scope_let_next in let scope_body_expr = let var, e = Bindlib.unbind body in - let env = A.Var.Map.add var ty_in env in + let env = Var.Map.add var ty_in env in let e' = process_scope_body_expr env e in - Bindlib.bind_var (A.Var.translate var) e' + Bindlib.bind_var (Var.translate var) e' in let scope_next = let scope_var, next = Bindlib.unbind scope_next in - let env = A.Var.Map.add scope_var ty_scope env in + let env = Var.Map.add scope_var ty_scope env in let next' = process_scopes env next in - Bindlib.bind_var (A.Var.translate scope_var) next' + Bindlib.bind_var (Var.translate scope_var) next' in Bindlib.box_apply2 (fun scope_body_expr scope_next -> @@ -747,5 +753,5 @@ let infer_types_program prg = }) scope_body_expr scope_next in - let scopes = Bindlib.unbox (process_scopes A.Var.Map.empty prg.scopes) in + let scopes = Bindlib.unbox (process_scopes Var.Map.empty prg.scopes) in { A.decl_ctx = ctx; scopes } diff --git a/compiler/dcalc/typing.mli b/compiler/shared_ast/typing.mli similarity index 78% rename from compiler/dcalc/typing.mli rename to compiler/shared_ast/typing.mli index 401ecce4..feb77127 100644 --- a/compiler/dcalc/typing.mli +++ b/compiler/shared_ast/typing.mli @@ -17,15 +17,18 @@ (** Typing for the default calculus. Because of the error terms, we perform type inference using the classical W algorithm with union-find unification. 
*) -open Shared_ast +open Definitions -val infer_types : decl_ctx -> untyped Ast.expr -> typed Ast.expr Bindlib.box +val infer_types : + decl_ctx -> ('a, untyped mark) gexpr -> ('a, typed mark) gexpr box (** Infers types everywhere on the given expression, and adds (or replaces) type annotations on each node *) -val infer_type : decl_ctx -> 'm Ast.expr -> typ +val infer_type : decl_ctx -> ('a, 'm mark) gexpr -> typ (** Gets the outer type of the given expression, using either the existing annotations or inference *) -val check_type : decl_ctx -> 'm Ast.expr -> typ -> unit -val infer_types_program : untyped Ast.program -> typed Ast.program +val check_type : decl_ctx -> ('a, 'm mark) gexpr -> typ -> unit (** Typechecks the given expression against the given expected type. Nothing is returned: the annotated expression built during the check is discarded. *) + +val infer_types_program : + ('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program (** Infers types over the scopes of the given program and returns the program with typed marks on its expressions. The declaration context of the input program is reused unchanged in the result. *)