Typing: simplify interface, split code in smaller functions

This commit is contained in:
Louis Gesbert 2022-09-26 12:12:39 +02:00
parent 51f79af13e
commit 05f4bb3537
3 changed files with 103 additions and 116 deletions

View File

@ -218,7 +218,7 @@ let driver source_file (options : Cli.options) : int =
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc
| `Proof | `Plugin _ ) as backend -> (
Cli.debug_print "Typechecking...";
let prgm = Shared_ast.Typing.infer_types_program prgm in
let prgm = Shared_ast.Typing.program prgm in
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
(Print.typ prgm.decl_ctx) typ); *)
match backend with

View File

@ -786,46 +786,21 @@ let wrap ctx f e =
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
(* Infer the type of an expression *)
let infer_types (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) :
(a, A.typed A.mark) A.gexpr A.box =
Expr.map_marks ~f:get_ty_mark
@@ wrap ctx (typecheck_expr_bottom_up ctx empty_env) e
let infer_type (type a m) ctx (e : (a, m A.mark) A.gexpr) =
match Marked.get_mark e with
| A.Typed { ty; _ } -> ty
| A.Untyped _ -> Expr.ty (Bindlib.unbox (infer_types ctx e))
(** Typechecks an expression given an expected type *)
let check_type (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) (tau : A.typ) =
(* todo: consider using the already inferred type if ['m] = [typed] *)
ignore @@ wrap ctx (typecheck_expr_top_down ctx empty_env (ast_to_typ tau)) e
let infer_types_program prg =
let ctx = prg.A.decl_ctx in
let rec process_scopes env = function
| A.Nil -> Bindlib.box A.Nil
| A.ScopeDef
{
scope_next;
scope_name;
scope_body =
{
scope_body_input_struct = s_in;
scope_body_output_struct = s_out;
scope_body_expr = body;
};
} ->
let scope_pos = Marked.get_mark (A.ScopeName.get_info scope_name) in
let struct_ty struct_name =
UnionFind.make (Marked.mark scope_pos (TStruct struct_name))
let expr
(type a)
(ctx : A.decl_ctx)
?(env = empty_env)
?(typ : A.typ option)
(e : (a, 'm) A.gexpr) : (a, A.typed A.mark) A.gexpr A.box =
let fty =
match typ with
| None -> typecheck_expr_bottom_up ctx env
| Some typ -> typecheck_expr_top_down ctx env (ast_to_typ typ)
in
let ty_in = struct_ty s_in in
let ty_out = struct_ty s_out in
let ty_scope =
UnionFind.make (Marked.mark scope_pos (TArrow (ty_in, ty_out)))
in
let rec process_scope_body_expr env = function
Expr.map_marks ~f:get_ty_mark (wrap ctx fty e)
let rec scope_body_expr ctx env ty_out body_expr =
match body_expr with
| A.Result e ->
let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in
let e' = Expr.map_marks ~f:get_ty_mark e' in
@ -842,12 +817,12 @@ let infer_types_program prg =
let e = wrap ctx (typecheck_expr_bottom_up ctx env) e0 in
wrap ctx (fun t -> Bindlib.box (unify ctx e0 (ty e) t)) ty_e;
(* We could use [typecheck_expr_top_down] rather than this manual
unification, but we get better messages with this order of the
[unify] parameters, which keeps location of the type as defined
instead of as inferred. *)
unification, but we get better messages with this order of the [unify]
parameters, which keeps location of the type as defined instead of as
inferred. *)
let var, next = Bindlib.unbind scope_let_next in
let env = { env with vars = Var.Map.add var ty_e env.vars } in
let next = process_scope_body_expr env next in
let next = scope_body_expr ctx env ty_out next in
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
Bindlib.box_apply2
(fun scope_let_expr scope_let_next ->
@ -861,33 +836,45 @@ let infer_types_program prg =
})
(Expr.map_marks ~f:get_ty_mark e)
scope_let_next
let scope_body ctx env body =
let get_pos struct_name =
Marked.get_mark (A.StructName.get_info struct_name)
in
let scope_body_expr =
let var, e = Bindlib.unbind body in
let struct_ty struct_name =
UnionFind.make (Marked.mark (get_pos struct_name) (TStruct struct_name))
in
let ty_in = struct_ty body.A.scope_body_input_struct in
let ty_out = struct_ty body.A.scope_body_output_struct in
let var, e = Bindlib.unbind body.A.scope_body_expr in
let env = { env with vars = Var.Map.add var ty_in env.vars } in
let e' = process_scope_body_expr env e in
Bindlib.bind_var (Var.translate var) e'
in
let e' = scope_body_expr ctx env ty_out e in
( Bindlib.bind_var (Var.translate var) e',
UnionFind.make
(Marked.mark
(get_pos body.A.scope_body_output_struct)
(TArrow (ty_in, ty_out))) )
let rec scopes ctx env = function
| A.Nil -> Bindlib.box A.Nil
| A.ScopeDef def ->
let body_e, ty_scope = scope_body ctx env def.scope_body in
let scope_next =
let scope_var, next = Bindlib.unbind scope_next in
let scope_var, next = Bindlib.unbind def.scope_next in
let env = { env with vars = Var.Map.add scope_var ty_scope env.vars } in
let next' = process_scopes env next in
let next' = scopes ctx env next in
Bindlib.bind_var (Var.translate scope_var) next'
in
Bindlib.box_apply2
(fun scope_body_expr scope_next ->
A.ScopeDef
{
def with
scope_body = { def.scope_body with scope_body_expr };
scope_next;
scope_name;
scope_body =
{
scope_body_input_struct = s_in;
scope_body_output_struct = s_out;
scope_body_expr;
};
})
scope_body_expr scope_next
in
let scopes = Bindlib.unbox (process_scopes empty_env prg.scopes) in
{ A.decl_ctx = ctx; scopes }
body_e scope_next
let program prg =
let scopes = Bindlib.unbox (scopes prg.A.decl_ctx empty_env prg.A.scopes) in
{ prg with scopes }

View File

@ -19,16 +19,16 @@
open Definitions
val infer_types :
decl_ctx -> ('a, untyped mark) gexpr -> ('a, typed mark) gexpr box
(** Infers types everywhere on the given expression, and adds (or replaces) type
annotations on each node *)
type 'e env
val infer_type : decl_ctx -> ('a, 'm mark) gexpr -> typ
(** Gets the outer type of the given expression, using either the existing
annotations or inference *)
val expr :
decl_ctx ->
?env:'e env ->
?typ:typ ->
(('a, 'm mark) gexpr as 'e) ->
('a, typed mark) gexpr box
(** Infers and marks the types for the given expression. If [typ] is provided,
it is assumed to be the outer type and used for inference top-down. *)
val check_type : decl_ctx -> ('a, 'm mark) gexpr -> typ -> unit
val infer_types_program :
('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program
val program : ('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program
(** Typing on whole programs *)