mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-09 22:16:10 +03:00
Typing: simplify interface, split code in smaller functions
This commit is contained in:
parent
51f79af13e
commit
05f4bb3537
@ -218,7 +218,7 @@ let driver source_file (options : Cli.options) : int =
|
|||||||
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc
|
| ( `Interpret | `Typecheck | `OCaml | `Python | `Scalc | `Lcalc
|
||||||
| `Proof | `Plugin _ ) as backend -> (
|
| `Proof | `Plugin _ ) as backend -> (
|
||||||
Cli.debug_print "Typechecking...";
|
Cli.debug_print "Typechecking...";
|
||||||
let prgm = Shared_ast.Typing.infer_types_program prgm in
|
let prgm = Shared_ast.Typing.program prgm in
|
||||||
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
|
(* Cli.debug_print (Format.asprintf "Typechecking results :@\n%a"
|
||||||
(Print.typ prgm.decl_ctx) typ); *)
|
(Print.typ prgm.decl_ctx) typ); *)
|
||||||
match backend with
|
match backend with
|
||||||
|
@ -786,108 +786,95 @@ let wrap ctx f e =
|
|||||||
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
|
let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
|
||||||
|
|
||||||
(* Infer the type of an expression *)
|
(* Infer the type of an expression *)
|
||||||
let infer_types (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) :
|
let expr
|
||||||
(a, A.typed A.mark) A.gexpr A.box =
|
(type a)
|
||||||
Expr.map_marks ~f:get_ty_mark
|
(ctx : A.decl_ctx)
|
||||||
@@ wrap ctx (typecheck_expr_bottom_up ctx empty_env) e
|
?(env = empty_env)
|
||||||
|
?(typ : A.typ option)
|
||||||
let infer_type (type a m) ctx (e : (a, m A.mark) A.gexpr) =
|
(e : (a, 'm) A.gexpr) : (a, A.typed A.mark) A.gexpr A.box =
|
||||||
match Marked.get_mark e with
|
let fty =
|
||||||
| A.Typed { ty; _ } -> ty
|
match typ with
|
||||||
| A.Untyped _ -> Expr.ty (Bindlib.unbox (infer_types ctx e))
|
| None -> typecheck_expr_bottom_up ctx env
|
||||||
|
| Some typ -> typecheck_expr_top_down ctx env (ast_to_typ typ)
|
||||||
(** Typechecks an expression given an expected type *)
|
|
||||||
let check_type (type a) (ctx : A.decl_ctx) (e : (a, 'm) A.gexpr) (tau : A.typ) =
|
|
||||||
(* todo: consider using the already inferred type if ['m] = [typed] *)
|
|
||||||
ignore @@ wrap ctx (typecheck_expr_top_down ctx empty_env (ast_to_typ tau)) e
|
|
||||||
|
|
||||||
let infer_types_program prg =
|
|
||||||
let ctx = prg.A.decl_ctx in
|
|
||||||
let rec process_scopes env = function
|
|
||||||
| A.Nil -> Bindlib.box A.Nil
|
|
||||||
| A.ScopeDef
|
|
||||||
{
|
|
||||||
scope_next;
|
|
||||||
scope_name;
|
|
||||||
scope_body =
|
|
||||||
{
|
|
||||||
scope_body_input_struct = s_in;
|
|
||||||
scope_body_output_struct = s_out;
|
|
||||||
scope_body_expr = body;
|
|
||||||
};
|
|
||||||
} ->
|
|
||||||
let scope_pos = Marked.get_mark (A.ScopeName.get_info scope_name) in
|
|
||||||
let struct_ty struct_name =
|
|
||||||
UnionFind.make (Marked.mark scope_pos (TStruct struct_name))
|
|
||||||
in
|
|
||||||
let ty_in = struct_ty s_in in
|
|
||||||
let ty_out = struct_ty s_out in
|
|
||||||
let ty_scope =
|
|
||||||
UnionFind.make (Marked.mark scope_pos (TArrow (ty_in, ty_out)))
|
|
||||||
in
|
|
||||||
let rec process_scope_body_expr env = function
|
|
||||||
| A.Result e ->
|
|
||||||
let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in
|
|
||||||
let e' = Expr.map_marks ~f:get_ty_mark e' in
|
|
||||||
Bindlib.box_apply (fun e -> A.Result e) e'
|
|
||||||
| A.ScopeLet
|
|
||||||
{
|
|
||||||
scope_let_kind;
|
|
||||||
scope_let_typ;
|
|
||||||
scope_let_expr = e0;
|
|
||||||
scope_let_next;
|
|
||||||
scope_let_pos;
|
|
||||||
} ->
|
|
||||||
let ty_e = ast_to_typ scope_let_typ in
|
|
||||||
let e = wrap ctx (typecheck_expr_bottom_up ctx env) e0 in
|
|
||||||
wrap ctx (fun t -> Bindlib.box (unify ctx e0 (ty e) t)) ty_e;
|
|
||||||
(* We could use [typecheck_expr_top_down] rather than this manual
|
|
||||||
unification, but we get better messages with this order of the
|
|
||||||
[unify] parameters, which keeps location of the type as defined
|
|
||||||
instead of as inferred. *)
|
|
||||||
let var, next = Bindlib.unbind scope_let_next in
|
|
||||||
let env = { env with vars = Var.Map.add var ty_e env.vars } in
|
|
||||||
let next = process_scope_body_expr env next in
|
|
||||||
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
|
|
||||||
Bindlib.box_apply2
|
|
||||||
(fun scope_let_expr scope_let_next ->
|
|
||||||
A.ScopeLet
|
|
||||||
{
|
|
||||||
scope_let_kind;
|
|
||||||
scope_let_typ;
|
|
||||||
scope_let_expr;
|
|
||||||
scope_let_next;
|
|
||||||
scope_let_pos;
|
|
||||||
})
|
|
||||||
(Expr.map_marks ~f:get_ty_mark e)
|
|
||||||
scope_let_next
|
|
||||||
in
|
|
||||||
let scope_body_expr =
|
|
||||||
let var, e = Bindlib.unbind body in
|
|
||||||
let env = { env with vars = Var.Map.add var ty_in env.vars } in
|
|
||||||
let e' = process_scope_body_expr env e in
|
|
||||||
Bindlib.bind_var (Var.translate var) e'
|
|
||||||
in
|
|
||||||
let scope_next =
|
|
||||||
let scope_var, next = Bindlib.unbind scope_next in
|
|
||||||
let env = { env with vars = Var.Map.add scope_var ty_scope env.vars } in
|
|
||||||
let next' = process_scopes env next in
|
|
||||||
Bindlib.bind_var (Var.translate scope_var) next'
|
|
||||||
in
|
|
||||||
Bindlib.box_apply2
|
|
||||||
(fun scope_body_expr scope_next ->
|
|
||||||
A.ScopeDef
|
|
||||||
{
|
|
||||||
scope_next;
|
|
||||||
scope_name;
|
|
||||||
scope_body =
|
|
||||||
{
|
|
||||||
scope_body_input_struct = s_in;
|
|
||||||
scope_body_output_struct = s_out;
|
|
||||||
scope_body_expr;
|
|
||||||
};
|
|
||||||
})
|
|
||||||
scope_body_expr scope_next
|
|
||||||
in
|
in
|
||||||
let scopes = Bindlib.unbox (process_scopes empty_env prg.scopes) in
|
Expr.map_marks ~f:get_ty_mark (wrap ctx fty e)
|
||||||
{ A.decl_ctx = ctx; scopes }
|
|
||||||
|
let rec scope_body_expr ctx env ty_out body_expr =
|
||||||
|
match body_expr with
|
||||||
|
| A.Result e ->
|
||||||
|
let e' = wrap ctx (typecheck_expr_top_down ctx env ty_out) e in
|
||||||
|
let e' = Expr.map_marks ~f:get_ty_mark e' in
|
||||||
|
Bindlib.box_apply (fun e -> A.Result e) e'
|
||||||
|
| A.ScopeLet
|
||||||
|
{
|
||||||
|
scope_let_kind;
|
||||||
|
scope_let_typ;
|
||||||
|
scope_let_expr = e0;
|
||||||
|
scope_let_next;
|
||||||
|
scope_let_pos;
|
||||||
|
} ->
|
||||||
|
let ty_e = ast_to_typ scope_let_typ in
|
||||||
|
let e = wrap ctx (typecheck_expr_bottom_up ctx env) e0 in
|
||||||
|
wrap ctx (fun t -> Bindlib.box (unify ctx e0 (ty e) t)) ty_e;
|
||||||
|
(* We could use [typecheck_expr_top_down] rather than this manual
|
||||||
|
unification, but we get better messages with this order of the [unify]
|
||||||
|
parameters, which keeps location of the type as defined instead of as
|
||||||
|
inferred. *)
|
||||||
|
let var, next = Bindlib.unbind scope_let_next in
|
||||||
|
let env = { env with vars = Var.Map.add var ty_e env.vars } in
|
||||||
|
let next = scope_body_expr ctx env ty_out next in
|
||||||
|
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
|
||||||
|
Bindlib.box_apply2
|
||||||
|
(fun scope_let_expr scope_let_next ->
|
||||||
|
A.ScopeLet
|
||||||
|
{
|
||||||
|
scope_let_kind;
|
||||||
|
scope_let_typ;
|
||||||
|
scope_let_expr;
|
||||||
|
scope_let_next;
|
||||||
|
scope_let_pos;
|
||||||
|
})
|
||||||
|
(Expr.map_marks ~f:get_ty_mark e)
|
||||||
|
scope_let_next
|
||||||
|
|
||||||
|
let scope_body ctx env body =
|
||||||
|
let get_pos struct_name =
|
||||||
|
Marked.get_mark (A.StructName.get_info struct_name)
|
||||||
|
in
|
||||||
|
let struct_ty struct_name =
|
||||||
|
UnionFind.make (Marked.mark (get_pos struct_name) (TStruct struct_name))
|
||||||
|
in
|
||||||
|
let ty_in = struct_ty body.A.scope_body_input_struct in
|
||||||
|
let ty_out = struct_ty body.A.scope_body_output_struct in
|
||||||
|
let var, e = Bindlib.unbind body.A.scope_body_expr in
|
||||||
|
let env = { env with vars = Var.Map.add var ty_in env.vars } in
|
||||||
|
let e' = scope_body_expr ctx env ty_out e in
|
||||||
|
( Bindlib.bind_var (Var.translate var) e',
|
||||||
|
UnionFind.make
|
||||||
|
(Marked.mark
|
||||||
|
(get_pos body.A.scope_body_output_struct)
|
||||||
|
(TArrow (ty_in, ty_out))) )
|
||||||
|
|
||||||
|
let rec scopes ctx env = function
|
||||||
|
| A.Nil -> Bindlib.box A.Nil
|
||||||
|
| A.ScopeDef def ->
|
||||||
|
let body_e, ty_scope = scope_body ctx env def.scope_body in
|
||||||
|
let scope_next =
|
||||||
|
let scope_var, next = Bindlib.unbind def.scope_next in
|
||||||
|
let env = { env with vars = Var.Map.add scope_var ty_scope env.vars } in
|
||||||
|
let next' = scopes ctx env next in
|
||||||
|
Bindlib.bind_var (Var.translate scope_var) next'
|
||||||
|
in
|
||||||
|
Bindlib.box_apply2
|
||||||
|
(fun scope_body_expr scope_next ->
|
||||||
|
A.ScopeDef
|
||||||
|
{
|
||||||
|
def with
|
||||||
|
scope_body = { def.scope_body with scope_body_expr };
|
||||||
|
scope_next;
|
||||||
|
})
|
||||||
|
body_e scope_next
|
||||||
|
|
||||||
|
let program prg =
|
||||||
|
let scopes = Bindlib.unbox (scopes prg.A.decl_ctx empty_env prg.A.scopes) in
|
||||||
|
{ prg with scopes }
|
||||||
|
@ -19,16 +19,16 @@
|
|||||||
|
|
||||||
open Definitions
|
open Definitions
|
||||||
|
|
||||||
val infer_types :
|
type 'e env
|
||||||
decl_ctx -> ('a, untyped mark) gexpr -> ('a, typed mark) gexpr box
|
|
||||||
(** Infers types everywhere on the given expression, and adds (or replaces) type
|
|
||||||
annotations on each node *)
|
|
||||||
|
|
||||||
val infer_type : decl_ctx -> ('a, 'm mark) gexpr -> typ
|
val expr :
|
||||||
(** Gets the outer type of the given expression, using either the existing
|
decl_ctx ->
|
||||||
annotations or inference *)
|
?env:'e env ->
|
||||||
|
?typ:typ ->
|
||||||
|
(('a, 'm mark) gexpr as 'e) ->
|
||||||
|
('a, typed mark) gexpr box
|
||||||
|
(** Infers and marks the types for the given expression. If [typ] is provided,
|
||||||
|
it is assumed to be the outer type and used for inference top-down. *)
|
||||||
|
|
||||||
val check_type : decl_ctx -> ('a, 'm mark) gexpr -> typ -> unit
|
val program : ('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program
|
||||||
|
(** Typing on whole programs *)
|
||||||
val infer_types_program :
|
|
||||||
('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program
|
|
||||||
|
Loading…
Reference in New Issue
Block a user