diff --git a/compiler/driver.ml b/compiler/driver.ml index 69873195..41260fb4 100644 --- a/compiler/driver.ml +++ b/compiler/driver.ml @@ -218,7 +218,7 @@ let driver source_file (options : Cli.options) : int = | ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc | `Proof | `Plugin _ ) as backend -> ( Cli.debug_print "Typechecking..."; - let prgm = Shared_ast.Typing.infer_types_program prgm in + let prgm = Shared_ast.Typing.program prgm in (* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a" (Print.typ prgm.decl_ctx) typ); *) match backend with diff --git a/compiler/shared_ast/typing.ml b/compiler/shared_ast/typing.ml index 68968471..9478b44f 100644 --- a/compiler/shared_ast/typing.ml +++ b/compiler/shared_ast/typing.ml @@ -786,108 +786,95 @@ let wrap ctx f e = let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos } (* Infer the type of an expression *) -let infer_types (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) : - (a, A.typed A.mark) A.gexpr A.box = - Expr.map_marks ~f:get_ty_mark - @@ wrap ctx (typecheck_expr_bottom_up ctx empty_env) e - -let infer_type (type a m) ctx (e : (a, m A.mark) A.gexpr) = - match Marked.get_mark e with - | A.Typed { ty; _ } -> ty - | A.Untyped _ -> Expr.ty (Bindlib.unbox (infer_types ctx e)) - -(** Typechecks an expression given an expected type *) -let check_type (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) (tau : A.typ) = - (* todo: consider using the already inferred type if ['m] = [typed] *) - ignore @@ wrap ctx (typecheck_expr_top_down ctx empty_env (ast_to_typ tau)) e - -let infer_types_program prg = - let ctx = prg.A.decl_ctx in - let rec process_scopes env = function - | A.Nil -> Bindlib.box A.Nil - | A.ScopeDef - { - scope_next; - scope_name; - scope_body = - { - scope_body_input_struct = s_in; - scope_body_output_struct = s_out; - scope_body_expr = body; - }; - } -> - let scope_pos = Marked.get_mark (A.ScopeName.get_info scope_name) in - let struct_ty struct_name = - UnionFind.make (Marked.mark scope_pos (TStruct struct_name)) - in - let ty_in = struct_ty s_in in - let ty_out = struct_ty s_out in - let ty_scope = - UnionFind.make (Marked.mark scope_pos (TArrow (ty_in, ty_out))) - in - let rec process_scope_body_expr env = function - | A.Result e -> - let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in - let e' = Expr.map_marks ~f:get_ty_mark e' in - Bindlib.box_apply (fun e -> A.Result e) e' - | A.ScopeLet - { - scope_let_kind; - scope_let_typ; - scope_let_expr = e0; - scope_let_next; - scope_let_pos; - } -> - let ty_e = ast_to_typ scope_let_typ in - let e = wrap ctx (typecheck_expr_bottom_up ctx env) e0 in - wrap ctx (fun t -> Bindlib.box (unify ctx e0 (ty e) t)) ty_e; - (* We could use [typecheck_expr_top_down] rather than this manual - unification, but we get better messages with this order of the - [unify] parameters, which keeps location of the type as defined - instead of as inferred. *) - let var, next = Bindlib.unbind scope_let_next in - let env = { env with vars = Var.Map.add var ty_e env.vars } in - let next = process_scope_body_expr env next in - let scope_let_next = Bindlib.bind_var (Var.translate var) next in - Bindlib.box_apply2 - (fun scope_let_expr scope_let_next -> - A.ScopeLet - { - scope_let_kind; - scope_let_typ; - scope_let_expr; - scope_let_next; - scope_let_pos; - }) - (Expr.map_marks ~f:get_ty_mark e) - scope_let_next - in - let scope_body_expr = - let var, e = Bindlib.unbind body in - let env = { env with vars = Var.Map.add var ty_in env.vars } in - let e' = process_scope_body_expr env e in - Bindlib.bind_var (Var.translate var) e' - in - let scope_next = - let scope_var, next = Bindlib.unbind scope_next in - let env = { env with vars = Var.Map.add scope_var ty_scope env.vars } in - let next' = process_scopes env next in - Bindlib.bind_var (Var.translate scope_var) next' - in - Bindlib.box_apply2 - (fun scope_body_expr scope_next -> - A.ScopeDef - { - scope_next; - scope_name; - scope_body = - { - scope_body_input_struct = s_in; - scope_body_output_struct = s_out; - scope_body_expr; - }; - }) - scope_body_expr scope_next +let expr + (type a) + (ctx : A.decl_ctx) + ?(env = empty_env) + ?(typ : A.typ option) + (e : (a, 'm) A.gexpr) : (a, A.typed A.mark) A.gexpr A.box = + let fty = + match typ with + | None -> typecheck_expr_bottom_up ctx env + | Some typ -> typecheck_expr_top_down ctx env (ast_to_typ typ) in - let scopes = Bindlib.unbox (process_scopes empty_env prg.scopes) in - { A.decl_ctx = ctx; scopes } + Expr.map_marks ~f:get_ty_mark (wrap ctx fty e) + +let rec scope_body_expr ctx env ty_out body_expr = + match body_expr with + | A.Result e -> + let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in + let e' = Expr.map_marks ~f:get_ty_mark e' in + Bindlib.box_apply (fun e -> A.Result e) e' + | A.ScopeLet + { + scope_let_kind; + scope_let_typ; + scope_let_expr = e0; + scope_let_next; + scope_let_pos; + } -> + let ty_e = ast_to_typ scope_let_typ in + let e = wrap ctx (typecheck_expr_bottom_up ctx env) e0 in + wrap ctx (fun t -> Bindlib.box (unify ctx e0 (ty e) t)) ty_e; + (* We could use [typecheck_expr_top_down] rather than this manual + unification, but we get better messages with this order of the [unify] + parameters, which keeps location of the type as defined instead of as + inferred. *) + let var, next = Bindlib.unbind scope_let_next in + let env = { env with vars = Var.Map.add var ty_e env.vars } in + let next = scope_body_expr ctx env ty_out next in + let scope_let_next = Bindlib.bind_var (Var.translate var) next in + Bindlib.box_apply2 + (fun scope_let_expr scope_let_next -> + A.ScopeLet + { + scope_let_kind; + scope_let_typ; + scope_let_expr; + scope_let_next; + scope_let_pos; + }) + (Expr.map_marks ~f:get_ty_mark e) + scope_let_next + +let scope_body ctx env body = + let get_pos struct_name = + Marked.get_mark (A.StructName.get_info struct_name) + in + let struct_ty struct_name = + UnionFind.make (Marked.mark (get_pos struct_name) (TStruct struct_name)) + in + let ty_in = struct_ty body.A.scope_body_input_struct in + let ty_out = struct_ty body.A.scope_body_output_struct in + let var, e = Bindlib.unbind body.A.scope_body_expr in + let env = { env with vars = Var.Map.add var ty_in env.vars } in + let e' = scope_body_expr ctx env ty_out e in + ( Bindlib.bind_var (Var.translate var) e', + UnionFind.make + (Marked.mark + (get_pos body.A.scope_body_output_struct) + (TArrow (ty_in, ty_out))) ) + +let rec scopes ctx env = function + | A.Nil -> Bindlib.box A.Nil + | A.ScopeDef def -> + let body_e, ty_scope = scope_body ctx env def.scope_body in + let scope_next = + let scope_var, next = Bindlib.unbind def.scope_next in + let env = { env with vars = Var.Map.add scope_var ty_scope env.vars } in + let next' = scopes ctx env next in + Bindlib.bind_var (Var.translate scope_var) next' + in + Bindlib.box_apply2 + (fun scope_body_expr scope_next -> + A.ScopeDef + { + def with + scope_body = { def.scope_body with scope_body_expr }; + scope_next; + }) + body_e scope_next + +let program prg = + let scopes = Bindlib.unbox (scopes prg.A.decl_ctx empty_env prg.A.scopes) in + { prg with scopes } diff --git a/compiler/shared_ast/typing.mli b/compiler/shared_ast/typing.mli index feb77127..582dca58 100644 --- a/compiler/shared_ast/typing.mli +++ b/compiler/shared_ast/typing.mli @@ -19,16 +19,16 @@ open Definitions -val infer_types : - decl_ctx -> ('a, untyped mark) gexpr -> ('a, typed mark) gexpr box -(** Infers types everywhere on the given expression, and adds (or replaces) type - annotations on each node *) +type 'e env -val infer_type : decl_ctx -> ('a, 'm mark) gexpr -> typ -(** Gets the outer type of the given expression, using either the existing - annotations or inference *) +val expr : + decl_ctx -> + ?env:'e env -> + ?typ:typ -> + (('a, 'm mark) gexpr as 'e) -> + ('a, typed mark) gexpr box +(** Infers and marks the types for the given expression. If [typ] is provided, + it is assumed to be the outer type and used for inference top-down. *) -val check_type : decl_ctx -> ('a, 'm mark) gexpr -> typ -> unit - -val infer_types_program : - ('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program +val program : ('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program +(** Typing on whole programs *)