mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 07:51:43 +03:00
Implement typing at the scopelang level
This commit is contained in:
parent
9f7a0f6078
commit
2955ef3235
@ -84,4 +84,37 @@ type 'm program = {
|
||||
program_ctx : decl_ctx;
|
||||
}
|
||||
|
||||
(* let type_program: untyped *)
|
||||
let type_rule decl_ctx env = function
|
||||
| Definition (loc, typ, io, expr) ->
|
||||
let expr' = Typing.expr decl_ctx ~env ~typ expr in
|
||||
Definition (loc, typ, io, Bindlib.unbox expr')
|
||||
| Assertion expr ->
|
||||
let expr' = Typing.expr decl_ctx ~env expr in
|
||||
Assertion (Bindlib.unbox expr')
|
||||
| Call (sc_name, ssc_name) -> Call (sc_name, ssc_name)
|
||||
|
||||
let type_program (prg : 'm program) : typed program =
|
||||
let typing_env =
|
||||
ScopeMap.fold
|
||||
(fun scope_name scope_decl ->
|
||||
Typing.Env.add_scope scope_name
|
||||
(ScopeVarMap.map fst scope_decl.scope_sig))
|
||||
prg.program_scopes Typing.Env.empty
|
||||
in
|
||||
let program_scopes =
|
||||
ScopeMap.map
|
||||
(fun scope_decl ->
|
||||
let typing_env =
|
||||
ScopeVarMap.fold
|
||||
(fun svar (typ, _) env -> Typing.Env.add_scope_var svar typ env)
|
||||
scope_decl.scope_sig typing_env
|
||||
in
|
||||
let scope_decl_rules =
|
||||
List.map
|
||||
(type_rule prg.program_ctx typing_env)
|
||||
scope_decl.scope_decl_rules
|
||||
in
|
||||
{ scope_decl with scope_decl_rules })
|
||||
prg.program_scopes
|
||||
in
|
||||
{ prg with program_scopes }
|
||||
|
@ -77,3 +77,5 @@ type 'm program = {
|
||||
program_scopes : 'm scope_decl ScopeMap.t;
|
||||
program_ctx : decl_ctx;
|
||||
}
|
||||
|
||||
val type_program : 'm program -> typed program
|
||||
|
@ -276,19 +276,37 @@ let op_type (op : A.operator Marked.pos) : unionfind_typ =
|
||||
|
||||
(** {1 Double-directed typing} *)
|
||||
|
||||
type 'e env = {
|
||||
vars : ('e, unionfind_typ) Var.Map.t;
|
||||
scope_vars : A.typ A.ScopeVarMap.t;
|
||||
subscope_vars : A.typ A.ScopeVarMap.t A.SubScopeMap.t;
|
||||
}
|
||||
|
||||
let empty_env =
|
||||
{
|
||||
vars = Var.Map.empty;
|
||||
scope_vars = A.ScopeVarMap.empty;
|
||||
subscope_vars = A.SubScopeMap.empty;
|
||||
module Env = struct
|
||||
type 'e t = {
|
||||
vars : ('e, unionfind_typ) Var.Map.t;
|
||||
scope_vars : A.typ A.ScopeVarMap.t;
|
||||
scopes : A.typ A.ScopeVarMap.t A.ScopeMap.t;
|
||||
}
|
||||
|
||||
let empty =
|
||||
{
|
||||
vars = Var.Map.empty;
|
||||
scope_vars = A.ScopeVarMap.empty;
|
||||
scopes = A.ScopeMap.empty;
|
||||
}
|
||||
|
||||
let get t v = Var.Map.find_opt v t.vars
|
||||
let get_scope_var t sv = A.ScopeVarMap.find_opt sv t.scope_vars
|
||||
|
||||
let get_subscope_var t scope var =
|
||||
Option.bind (A.ScopeMap.find_opt scope t.scopes) (fun vmap ->
|
||||
A.ScopeVarMap.find_opt var vmap)
|
||||
|
||||
let add v tau t = { t with vars = Var.Map.add v tau t.vars }
|
||||
let add_var v typ t = add v (ast_to_typ typ) t
|
||||
|
||||
let add_scope_var v typ t =
|
||||
{ t with scope_vars = A.ScopeVarMap.add v typ t.scope_vars }
|
||||
|
||||
let add_scope scope_name vmap t =
|
||||
{ t with scopes = A.ScopeMap.add scope_name vmap t.scopes }
|
||||
end
|
||||
|
||||
let add_pos e ty = Marked.mark (Expr.pos e) ty
|
||||
let ty (_, { uf; _ }) = uf
|
||||
let ( let+ ) x f = Bindlib.box_apply f x
|
||||
@ -318,7 +336,7 @@ let box_ty e = Bindlib.unbox (Bindlib.box_apply ty e)
|
||||
let rec typecheck_expr_bottom_up :
|
||||
type a.
|
||||
A.decl_ctx ->
|
||||
(a, 'm A.mark) A.gexpr env ->
|
||||
(a, 'm A.mark) A.gexpr Env.t ->
|
||||
(a, 'm A.mark) A.gexpr ->
|
||||
(a, mark) A.gexpr A.box =
|
||||
fun ctx env e ->
|
||||
@ -329,16 +347,19 @@ let rec typecheck_expr_bottom_up :
|
||||
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
||||
let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in
|
||||
match Marked.unmark e with
|
||||
| A.ELocation loc as e1 ->
|
||||
| A.ELocation loc as e1 -> (
|
||||
let ty =
|
||||
match loc with
|
||||
| DesugaredScopeVar (v, _) | ScopelangScopeVar v ->
|
||||
A.ScopeVarMap.find (Marked.unmark v) env.scope_vars
|
||||
| SubScopeVar (_s_name, ss_name, v) ->
|
||||
A.ScopeVarMap.find (Marked.unmark v)
|
||||
(A.SubScopeMap.find (Marked.unmark ss_name) env.subscope_vars)
|
||||
Env.get_scope_var env (Marked.unmark v)
|
||||
| SubScopeVar (scope_name, _, v) ->
|
||||
Env.get_subscope_var env scope_name (Marked.unmark v)
|
||||
in
|
||||
Bindlib.box (mark e1 (ast_to_typ ty))
|
||||
match ty with
|
||||
| Some ty -> Bindlib.box (mark e1 (ast_to_typ ty))
|
||||
| None ->
|
||||
Errors.raise_spanned_error pos_e "Reference to %a not found"
|
||||
(Expr.format ctx) e)
|
||||
| A.EStruct (s_name, fmap) ->
|
||||
let+ fmap' =
|
||||
(* This assumes that the fields in fmap and the struct type are already
|
||||
@ -394,7 +415,7 @@ let rec typecheck_expr_bottom_up :
|
||||
unify ctx e e_ty (ty e2');
|
||||
mark (A.ECatch (e1', ex, e2')) e_ty
|
||||
| A.EVar v -> begin
|
||||
match Var.Map.find_opt v env.vars with
|
||||
match Env.get env v with
|
||||
| Some t ->
|
||||
let+ v' = Bindlib.box_var (Var.translate v) in
|
||||
mark v' t
|
||||
@ -466,13 +487,7 @@ let rec typecheck_expr_bottom_up :
|
||||
let xs' = Array.map Var.translate xs in
|
||||
let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in
|
||||
let env =
|
||||
{
|
||||
env with
|
||||
vars =
|
||||
List.fold_left
|
||||
(fun env (x, tau) -> Var.Map.add x tau env)
|
||||
env.vars xstaus;
|
||||
}
|
||||
List.fold_left (fun env (x, tau) -> Env.add x tau env) env xstaus
|
||||
in
|
||||
let body' = typecheck_expr_bottom_up ctx env body in
|
||||
let t_func =
|
||||
@ -543,7 +558,7 @@ let rec typecheck_expr_bottom_up :
|
||||
and typecheck_expr_top_down :
|
||||
type a.
|
||||
A.decl_ctx ->
|
||||
(a, 'm A.mark) A.gexpr env ->
|
||||
(a, 'm A.mark) A.gexpr Env.t ->
|
||||
unionfind_typ ->
|
||||
(a, 'm A.mark) A.gexpr ->
|
||||
(a, mark) A.gexpr Bindlib.box =
|
||||
@ -558,16 +573,19 @@ and typecheck_expr_top_down :
|
||||
in
|
||||
let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
|
||||
match Marked.unmark e with
|
||||
| A.ELocation loc as e1 ->
|
||||
| A.ELocation loc as e1 -> (
|
||||
let ty =
|
||||
match loc with
|
||||
| DesugaredScopeVar (v, _) | ScopelangScopeVar v ->
|
||||
A.ScopeVarMap.find (Marked.unmark v) env.scope_vars
|
||||
| SubScopeVar (_s_name, ss_name, v) ->
|
||||
A.ScopeVarMap.find (Marked.unmark v)
|
||||
(A.SubScopeMap.find (Marked.unmark ss_name) env.subscope_vars)
|
||||
Env.get_scope_var env (Marked.unmark v)
|
||||
| SubScopeVar (scope, _, v) ->
|
||||
Env.get_subscope_var env scope (Marked.unmark v)
|
||||
in
|
||||
unify_and_mark (ast_to_typ ty) (fun () -> Bindlib.box e1)
|
||||
match ty with
|
||||
| Some ty -> unify_and_mark (ast_to_typ ty) (fun () -> Bindlib.box e1)
|
||||
| None ->
|
||||
Errors.raise_spanned_error pos_e "Reference to %a not found"
|
||||
(Expr.format ctx) e)
|
||||
| A.EStruct (s_name, fmap) ->
|
||||
unify_and_mark (unionfind_make (TStruct s_name))
|
||||
@@ fun () ->
|
||||
@ -626,8 +644,9 @@ and typecheck_expr_top_down :
|
||||
mark (A.ECatch (e1', ex, e2'))
|
||||
| A.EVar v ->
|
||||
let tau' =
|
||||
try Var.Map.find v env.vars
|
||||
with Not_found ->
|
||||
match Env.get env v with
|
||||
| Some t -> t
|
||||
| None ->
|
||||
Errors.raise_spanned_error pos_e
|
||||
"Variable %s not found in the current context" (Bindlib.name_of v)
|
||||
in
|
||||
@ -712,13 +731,9 @@ and typecheck_expr_top_down :
|
||||
let xs, body = Bindlib.unmbind binder in
|
||||
let xs' = Array.map Var.translate xs in
|
||||
let env =
|
||||
{
|
||||
env with
|
||||
vars =
|
||||
List.fold_left2
|
||||
(fun env x tau_arg -> Var.Map.add x tau_arg env)
|
||||
env.vars (Array.to_list xs) tau_args;
|
||||
}
|
||||
List.fold_left2
|
||||
(fun env x tau_arg -> Env.add x tau_arg env)
|
||||
env (Array.to_list xs) tau_args
|
||||
in
|
||||
let body' = typecheck_expr_top_down ctx env t_ret body in
|
||||
let+ binder' = Bindlib.bind_mvar xs' body' in
|
||||
@ -789,7 +804,7 @@ let get_ty_mark { uf; pos } = A.Typed { ty = typ_to_ast uf; pos }
|
||||
let expr
|
||||
(type a)
|
||||
(ctx : A.decl_ctx)
|
||||
?(env = empty_env)
|
||||
?(env = Env.empty)
|
||||
?(typ : A.typ option)
|
||||
(e : (a, 'm) A.gexpr) : (a, A.typed A.mark) A.gexpr A.box =
|
||||
let fty =
|
||||
@ -821,7 +836,7 @@ let rec scope_body_expr ctx env ty_out body_expr =
|
||||
parameters, which keeps location of the type as defined instead of as
|
||||
inferred. *)
|
||||
let var, next = Bindlib.unbind scope_let_next in
|
||||
let env = { env with vars = Var.Map.add var ty_e env.vars } in
|
||||
let env = Env.add var ty_e env in
|
||||
let next = scope_body_expr ctx env ty_out next in
|
||||
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
|
||||
Bindlib.box_apply2
|
||||
@ -847,7 +862,7 @@ let scope_body ctx env body =
|
||||
let ty_in = struct_ty body.A.scope_body_input_struct in
|
||||
let ty_out = struct_ty body.A.scope_body_output_struct in
|
||||
let var, e = Bindlib.unbind body.A.scope_body_expr in
|
||||
let env = { env with vars = Var.Map.add var ty_in env.vars } in
|
||||
let env = Env.add var ty_in env in
|
||||
let e' = scope_body_expr ctx env ty_out e in
|
||||
( Bindlib.bind_var (Var.translate var) e',
|
||||
UnionFind.make
|
||||
@ -861,7 +876,7 @@ let rec scopes ctx env = function
|
||||
let body_e, ty_scope = scope_body ctx env def.scope_body in
|
||||
let scope_next =
|
||||
let scope_var, next = Bindlib.unbind def.scope_next in
|
||||
let env = { env with vars = Var.Map.add scope_var ty_scope env.vars } in
|
||||
let env = Env.add scope_var ty_scope env in
|
||||
let next' = scopes ctx env next in
|
||||
Bindlib.bind_var (Var.translate scope_var) next'
|
||||
in
|
||||
@ -876,5 +891,5 @@ let rec scopes ctx env = function
|
||||
body_e scope_next
|
||||
|
||||
let program prg =
|
||||
let scopes = Bindlib.unbox (scopes prg.A.decl_ctx empty_env prg.A.scopes) in
|
||||
let scopes = Bindlib.unbox (scopes prg.A.decl_ctx Env.empty prg.A.scopes) in
|
||||
{ prg with scopes }
|
||||
|
@ -19,11 +19,18 @@
|
||||
|
||||
open Definitions
|
||||
|
||||
type 'e env
|
||||
module Env : sig
|
||||
type 'e t
|
||||
|
||||
val empty : 'e t
|
||||
val add_var : 'e Var.t -> typ -> 'e t -> 'e t
|
||||
val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t
|
||||
val add_scope : ScopeName.t -> typ ScopeVarMap.t -> 'e t -> 'e t
|
||||
end
|
||||
|
||||
val expr :
|
||||
decl_ctx ->
|
||||
?env:'e env ->
|
||||
?env:'e Env.t ->
|
||||
?typ:typ ->
|
||||
(('a, 'm mark) gexpr as 'e) ->
|
||||
('a, typed mark) gexpr box
|
||||
@ -31,4 +38,5 @@ val expr :
|
||||
it is assumed to be the outer type and used for inference top-down. *)
|
||||
|
||||
val program : ('a, untyped mark) gexpr program -> ('a, typed mark) gexpr program
|
||||
(** Typing on whole programs *)
|
||||
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the
|
||||
later dcalc/lcalc stages *)
|
||||
|
Loading…
Reference in New Issue
Block a user