mirror of
https://github.com/CatalaLang/catala.git
synced 2024-11-08 17:10:22 +03:00
Refactoring up to typechecker, building up to interpretation
This commit is contained in:
parent
7ba05c5633
commit
e7dfa41d1a
@ -16,6 +16,13 @@ module IdentMap = Map.Make (String)
|
||||
|
||||
(* Printing functions for Lambda_ast.term *)
|
||||
|
||||
let rec format_typ (ty : Lambda_ast.typ) : string =
|
||||
match ty with
|
||||
| TBool -> "bool"
|
||||
| TInt -> "int"
|
||||
| TArrow (t1, t2) -> Format.sprintf "(%s) -> (%s)" (format_typ t1) (format_typ t2)
|
||||
| TDummy -> "??"
|
||||
|
||||
(** Operator printer *)
|
||||
let print_op (op : Lambda_ast.op) : string =
|
||||
match op with
|
||||
|
@ -13,27 +13,24 @@
|
||||
the License. *)
|
||||
|
||||
(** Checks that a term is well typed and annotate it *)
|
||||
let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.term) :
|
||||
Lambda_ast.term =
|
||||
let rec type_term (ctxt : Name_resolution.context) (local_ctx : Lambda_ast.typ Uid.LocalVarMap.t)
|
||||
(((t, pos), _) : Lambda_ast.term) : Lambda_ast.term =
|
||||
match t with
|
||||
| EVar uid ->
|
||||
let typ = Name_resolution.get_uid_typ ctxt uid in
|
||||
((EVar uid, pos), typ)
|
||||
| ELocalVar _ -> assert false
|
||||
| EVar (s, uid) ->
|
||||
(* so here we can ignore the subscope uid because the uid of the variable already corresponds
|
||||
to the uid of the variable in its original scope *)
|
||||
let typ = Name_resolution.get_var_typ ctxt uid in
|
||||
((EVar (s, uid), pos), typ)
|
||||
| ELocalVar uid ->
|
||||
let typ = Uid.LocalVarMap.find uid local_ctx in
|
||||
((ELocalVar uid, pos), typ)
|
||||
| EFun (bindings, body) ->
|
||||
(* Note that given the context formation process, the binder will already be present in the
|
||||
context (since we are working with uids), however it's added there for the sake of clarity *)
|
||||
let ctxt_data =
|
||||
let local_ctx =
|
||||
List.fold_left
|
||||
(fun data (uid, arg_typ) ->
|
||||
let uid_data : Name_resolution.uid_data =
|
||||
{ uid_typ = arg_typ; uid_sort = Name_resolution.IdBinder }
|
||||
in
|
||||
Uid.UidMap.add uid uid_data data)
|
||||
ctxt.data bindings
|
||||
(fun local_ctx (binding, ty) -> Uid.LocalVarMap.add binding ty local_ctx)
|
||||
local_ctx bindings
|
||||
in
|
||||
|
||||
let body = typing { ctxt with data = ctxt_data } body in
|
||||
let body = type_term ctxt local_ctx body in
|
||||
let ret_typ = Lambda_ast.get_typ body in
|
||||
let rec build_typ = function
|
||||
| [] -> ret_typ
|
||||
@ -42,28 +39,61 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t
|
||||
let fun_typ = build_typ bindings in
|
||||
((EFun (bindings, body), pos), fun_typ)
|
||||
| EApp (f, args) ->
|
||||
let f = typing ctxt f in
|
||||
let f = type_term ctxt local_ctx f in
|
||||
let f_typ = Lambda_ast.get_typ f in
|
||||
let args = List.map (typing ctxt) args in
|
||||
let args_typ = List.map Lambda_ast.get_typ args in
|
||||
let args = List.map (type_term ctxt local_ctx) args in
|
||||
let args_typ =
|
||||
List.map (fun arg -> (Lambda_ast.get_typ arg, Pos.get_position (fst arg))) args
|
||||
in
|
||||
let rec check_arrow_typ f_typ args_typ =
|
||||
match (f_typ, args_typ) with
|
||||
| typ, [] -> typ
|
||||
| TArrow (arg_typ, ret_typ), fst_typ :: typs ->
|
||||
if arg_typ = fst_typ then check_arrow_typ ret_typ typs else assert false
|
||||
| _ -> assert false
|
||||
| Lambda_ast.TArrow (arg_typ, ret_typ), fst_typ :: typs ->
|
||||
let fst_typ_s = Pos.unmark fst_typ in
|
||||
if arg_typ = fst_typ_s then check_arrow_typ ret_typ typs
|
||||
else
|
||||
Errors.raise_multispanned_error "error when comparing types of function arguments"
|
||||
[
|
||||
( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)),
|
||||
Pos.get_position (fst f) );
|
||||
( Some (Printf.sprintf "got type %s" (Format_lambda.format_typ fst_typ_s)),
|
||||
Pos.get_position fst_typ );
|
||||
]
|
||||
| _ ->
|
||||
Errors.raise_multispanned_error "wrong number of arguments for function call"
|
||||
[
|
||||
( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)),
|
||||
Pos.get_position (fst f) );
|
||||
( Some
|
||||
(Printf.sprintf "got type %s"
|
||||
(String.concat " -> "
|
||||
(List.map (fun (ty, _) -> Format_lambda.format_typ ty) args_typ))),
|
||||
Pos.get_position (List.hd args_typ) );
|
||||
]
|
||||
in
|
||||
let ret_typ = check_arrow_typ f_typ args_typ in
|
||||
((EApp (f, args), pos), ret_typ)
|
||||
| EIfThenElse (t_if, t_then, t_else) ->
|
||||
let t_if = typing ctxt t_if in
|
||||
let t_if = type_term ctxt local_ctx t_if in
|
||||
let typ_if = Lambda_ast.get_typ t_if in
|
||||
let t_then = typing ctxt t_then in
|
||||
let t_then = type_term ctxt local_ctx t_then in
|
||||
let typ_then = Lambda_ast.get_typ t_then in
|
||||
let t_else = typing ctxt t_else in
|
||||
let t_else = type_term ctxt local_ctx t_else in
|
||||
let typ_else = Lambda_ast.get_typ t_else in
|
||||
if typ_if <> TBool then assert false
|
||||
else if typ_then <> typ_else then assert false
|
||||
if typ_if <> TBool then
|
||||
Errors.raise_spanned_error
|
||||
(Format.sprintf "expecting type bool, got type %s" (Format_lambda.format_typ typ_if))
|
||||
(Pos.get_position (fst t_if))
|
||||
else if typ_then <> typ_else then
|
||||
Errors.raise_multispanned_error
|
||||
"expecting same types for the true and false branches of the conditional"
|
||||
[
|
||||
( Some (Format.sprintf "the true branch has type %s" (Format_lambda.format_typ typ_then)),
|
||||
Pos.get_position (fst t_then) );
|
||||
( Some
|
||||
(Format.sprintf "the false branch has type %s" (Format_lambda.format_typ typ_else)),
|
||||
Pos.get_position (fst t_else) );
|
||||
]
|
||||
else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then)
|
||||
| EInt _ | EDec _ -> ((t, pos), TInt)
|
||||
| EBool _ -> ((t, pos), TBool)
|
||||
@ -72,7 +102,7 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t
|
||||
match op with
|
||||
| Binop binop -> (
|
||||
match binop with
|
||||
| And | Or -> TArrow (TBool, TArrow (TBool, TBool))
|
||||
| And | Or -> Lambda_ast.TArrow (TBool, TArrow (TBool, TBool))
|
||||
| Add | Sub | Mult | Div -> TArrow (TInt, TArrow (TInt, TInt))
|
||||
| Lt | Lte | Gt | Gte | Eq | Neq -> TArrow (TInt, TArrow (TInt, TBool)) )
|
||||
| Unop Minus -> TArrow (TInt, TInt)
|
||||
@ -81,17 +111,28 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t
|
||||
((t, pos), typ)
|
||||
| EDefault t ->
|
||||
let defaults =
|
||||
IntMap.map
|
||||
List.map
|
||||
(fun (just, cons) ->
|
||||
let just = typing ctxt just in
|
||||
if Lambda_ast.get_typ just <> TBool then
|
||||
let cons = typing ctxt cons in
|
||||
(just, cons)
|
||||
else assert false)
|
||||
let just_t = type_term ctxt local_ctx just in
|
||||
if Lambda_ast.get_typ just_t <> TBool then
|
||||
let cons = type_term ctxt local_ctx cons in
|
||||
(just_t, cons)
|
||||
else
|
||||
Errors.raise_spanned_error
|
||||
(Format.sprintf "expected type of default condition to be bool, got %s"
|
||||
(Format_lambda.format_typ (Lambda_ast.get_typ just)))
|
||||
(Pos.get_position (fst just)))
|
||||
t.defaults
|
||||
in
|
||||
let typ_cons = IntMap.choose defaults |> snd |> snd |> Lambda_ast.get_typ in
|
||||
IntMap.iter
|
||||
(fun _ (_, cons) -> if Lambda_ast.get_typ cons <> typ_cons then assert false else ())
|
||||
let typ_cons = List.hd defaults |> snd |> snd in
|
||||
List.iter
|
||||
(fun (_, cons) ->
|
||||
if Lambda_ast.get_typ cons <> typ_cons then
|
||||
Errors.raise_spanned_error
|
||||
(Format.sprintf "expected default condition to be of type %s, got type %s"
|
||||
(Format_lambda.format_typ (Lambda_ast.get_typ cons))
|
||||
(Format_lambda.format_typ typ_cons))
|
||||
(Pos.get_position (fst cons))
|
||||
else ())
|
||||
defaults;
|
||||
((EDefault { t with defaults }, pos), typ_cons)
|
||||
|
Loading…
Reference in New Issue
Block a user