Refactoring up to typechecker, building up to interpretation

This commit is contained in:
Denis Merigoux 2020-09-13 01:05:06 +02:00
parent 7ba05c5633
commit e7dfa41d1a
2 changed files with 86 additions and 38 deletions

View File

@ -16,6 +16,13 @@ module IdentMap = Map.Make (String)
(* Printing functions for Lambda_ast.term *) (* Printing functions for Lambda_ast.term *)
let rec format_typ (ty : Lambda_ast.typ) : string =
match ty with
| TBool -> "bool"
| TInt -> "int"
| TArrow (t1, t2) -> Format.sprintf "(%s) -> (%s)" (format_typ t1) (format_typ t2)
| TDummy -> "??"
(** Operator printer *) (** Operator printer *)
let print_op (op : Lambda_ast.op) : string = let print_op (op : Lambda_ast.op) : string =
match op with match op with

View File

@ -13,27 +13,24 @@
the License. *) the License. *)
(** Checks that a term is well typed and annotate it *) (** Checks that a term is well typed and annotate it *)
let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.term) : let rec type_term (ctxt : Name_resolution.context) (local_ctx : Lambda_ast.typ Uid.LocalVarMap.t)
Lambda_ast.term = (((t, pos), _) : Lambda_ast.term) : Lambda_ast.term =
match t with match t with
| EVar uid -> | EVar (s, uid) ->
let typ = Name_resolution.get_uid_typ ctxt uid in (* so here we can ignore the subscope uid because the uid of the variable already corresponds
((EVar uid, pos), typ) to the uid of the variable in its original scope *)
| ELocalVar _ -> assert false let typ = Name_resolution.get_var_typ ctxt uid in
((EVar (s, uid), pos), typ)
| ELocalVar uid ->
let typ = Uid.LocalVarMap.find uid local_ctx in
((ELocalVar uid, pos), typ)
| EFun (bindings, body) -> | EFun (bindings, body) ->
(* Note that given the context formation process, the binder will already be present in the let local_ctx =
context (since we are working with uids), however it's added there for the sake of clarity *)
let ctxt_data =
List.fold_left List.fold_left
(fun data (uid, arg_typ) -> (fun local_ctx (binding, ty) -> Uid.LocalVarMap.add binding ty local_ctx)
let uid_data : Name_resolution.uid_data = local_ctx bindings
{ uid_typ = arg_typ; uid_sort = Name_resolution.IdBinder }
in in
Uid.UidMap.add uid uid_data data) let body = type_term ctxt local_ctx body in
ctxt.data bindings
in
let body = typing { ctxt with data = ctxt_data } body in
let ret_typ = Lambda_ast.get_typ body in let ret_typ = Lambda_ast.get_typ body in
let rec build_typ = function let rec build_typ = function
| [] -> ret_typ | [] -> ret_typ
@ -42,28 +39,61 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t
let fun_typ = build_typ bindings in let fun_typ = build_typ bindings in
((EFun (bindings, body), pos), fun_typ) ((EFun (bindings, body), pos), fun_typ)
| EApp (f, args) -> | EApp (f, args) ->
let f = typing ctxt f in let f = type_term ctxt local_ctx f in
let f_typ = Lambda_ast.get_typ f in let f_typ = Lambda_ast.get_typ f in
let args = List.map (typing ctxt) args in let args = List.map (type_term ctxt local_ctx) args in
let args_typ = List.map Lambda_ast.get_typ args in let args_typ =
List.map (fun arg -> (Lambda_ast.get_typ arg, Pos.get_position (fst arg))) args
in
let rec check_arrow_typ f_typ args_typ = let rec check_arrow_typ f_typ args_typ =
match (f_typ, args_typ) with match (f_typ, args_typ) with
| typ, [] -> typ | typ, [] -> typ
| TArrow (arg_typ, ret_typ), fst_typ :: typs -> | Lambda_ast.TArrow (arg_typ, ret_typ), fst_typ :: typs ->
if arg_typ = fst_typ then check_arrow_typ ret_typ typs else assert false let fst_typ_s = Pos.unmark fst_typ in
| _ -> assert false if arg_typ = fst_typ_s then check_arrow_typ ret_typ typs
else
Errors.raise_multispanned_error "error when comparing types of function arguments"
[
( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)),
Pos.get_position (fst f) );
( Some (Printf.sprintf "got type %s" (Format_lambda.format_typ fst_typ_s)),
Pos.get_position fst_typ );
]
| _ ->
Errors.raise_multispanned_error "wrong number of arguments for function call"
[
( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)),
Pos.get_position (fst f) );
( Some
(Printf.sprintf "got type %s"
(String.concat " -> "
(List.map (fun (ty, _) -> Format_lambda.format_typ ty) args_typ))),
Pos.get_position (List.hd args_typ) );
]
in in
let ret_typ = check_arrow_typ f_typ args_typ in let ret_typ = check_arrow_typ f_typ args_typ in
((EApp (f, args), pos), ret_typ) ((EApp (f, args), pos), ret_typ)
| EIfThenElse (t_if, t_then, t_else) -> | EIfThenElse (t_if, t_then, t_else) ->
let t_if = typing ctxt t_if in let t_if = type_term ctxt local_ctx t_if in
let typ_if = Lambda_ast.get_typ t_if in let typ_if = Lambda_ast.get_typ t_if in
let t_then = typing ctxt t_then in let t_then = type_term ctxt local_ctx t_then in
let typ_then = Lambda_ast.get_typ t_then in let typ_then = Lambda_ast.get_typ t_then in
let t_else = typing ctxt t_else in let t_else = type_term ctxt local_ctx t_else in
let typ_else = Lambda_ast.get_typ t_else in let typ_else = Lambda_ast.get_typ t_else in
if typ_if <> TBool then assert false if typ_if <> TBool then
else if typ_then <> typ_else then assert false Errors.raise_spanned_error
(Format.sprintf "expecting type bool, got type %s" (Format_lambda.format_typ typ_if))
(Pos.get_position (fst t_if))
else if typ_then <> typ_else then
Errors.raise_multispanned_error
"expecting same types for the true and false branches of the conditional"
[
( Some (Format.sprintf "the true branch has type %s" (Format_lambda.format_typ typ_then)),
Pos.get_position (fst t_then) );
( Some
(Format.sprintf "the false branch has type %s" (Format_lambda.format_typ typ_else)),
Pos.get_position (fst t_else) );
]
else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then) else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then)
| EInt _ | EDec _ -> ((t, pos), TInt) | EInt _ | EDec _ -> ((t, pos), TInt)
| EBool _ -> ((t, pos), TBool) | EBool _ -> ((t, pos), TBool)
@ -72,7 +102,7 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t
match op with match op with
| Binop binop -> ( | Binop binop -> (
match binop with match binop with
| And | Or -> TArrow (TBool, TArrow (TBool, TBool)) | And | Or -> Lambda_ast.TArrow (TBool, TArrow (TBool, TBool))
| Add | Sub | Mult | Div -> TArrow (TInt, TArrow (TInt, TInt)) | Add | Sub | Mult | Div -> TArrow (TInt, TArrow (TInt, TInt))
| Lt | Lte | Gt | Gte | Eq | Neq -> TArrow (TInt, TArrow (TInt, TBool)) ) | Lt | Lte | Gt | Gte | Eq | Neq -> TArrow (TInt, TArrow (TInt, TBool)) )
| Unop Minus -> TArrow (TInt, TInt) | Unop Minus -> TArrow (TInt, TInt)
@ -81,17 +111,28 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t
((t, pos), typ) ((t, pos), typ)
| EDefault t -> | EDefault t ->
let defaults = let defaults =
IntMap.map List.map
(fun (just, cons) -> (fun (just, cons) ->
let just = typing ctxt just in let just_t = type_term ctxt local_ctx just in
if Lambda_ast.get_typ just <> TBool then if Lambda_ast.get_typ just_t <> TBool then
let cons = typing ctxt cons in let cons = type_term ctxt local_ctx cons in
(just, cons) (just_t, cons)
else assert false) else
Errors.raise_spanned_error
(Format.sprintf "expected type of default condition to be bool, got %s"
(Format_lambda.format_typ (Lambda_ast.get_typ just)))
(Pos.get_position (fst just)))
t.defaults t.defaults
in in
let typ_cons = IntMap.choose defaults |> snd |> snd |> Lambda_ast.get_typ in let typ_cons = List.hd defaults |> snd |> snd in
IntMap.iter List.iter
(fun _ (_, cons) -> if Lambda_ast.get_typ cons <> typ_cons then assert false else ()) (fun (_, cons) ->
if Lambda_ast.get_typ cons <> typ_cons then
Errors.raise_spanned_error
(Format.sprintf "expected default condition to be of type %s, got type %s"
(Format_lambda.format_typ (Lambda_ast.get_typ cons))
(Format_lambda.format_typ typ_cons))
(Pos.get_position (fst cons))
else ())
defaults; defaults;
((EDefault { t with defaults }, pos), typ_cons) ((EDefault { t with defaults }, pos), typ_cons)