From e7dfa41d1a9239dacb0f15d541426afa2125c845 Mon Sep 17 00:00:00 2001 From: Denis Merigoux Date: Sun, 13 Sep 2020 01:05:06 +0200 Subject: [PATCH] Refactoring up to typechecker, building up to interpretation --- src/catala/lambda_calculus/format_lambda.ml | 7 ++ .../lambda_calculus/lambda_typechecker.ml | 117 ++++++++++++------ 2 files changed, 86 insertions(+), 38 deletions(-) diff --git a/src/catala/lambda_calculus/format_lambda.ml b/src/catala/lambda_calculus/format_lambda.ml index 0f16a952..7fa3d5ef 100644 --- a/src/catala/lambda_calculus/format_lambda.ml +++ b/src/catala/lambda_calculus/format_lambda.ml @@ -16,6 +16,13 @@ module IdentMap = Map.Make (String) (* Printing functions for Lambda_ast.term *) +let rec format_typ (ty : Lambda_ast.typ) : string = + match ty with + | TBool -> "bool" + | TInt -> "int" + | TArrow (t1, t2) -> Format.sprintf "(%s) -> (%s)" (format_typ t1) (format_typ t2) + | TDummy -> "??" + (** Operator printer *) let print_op (op : Lambda_ast.op) : string = match op with diff --git a/src/catala/lambda_calculus/lambda_typechecker.ml b/src/catala/lambda_calculus/lambda_typechecker.ml index fa02f7be..4f7885f0 100644 --- a/src/catala/lambda_calculus/lambda_typechecker.ml +++ b/src/catala/lambda_calculus/lambda_typechecker.ml @@ -13,27 +13,24 @@ the License. *) (** Checks that a term is well typed and annotate it *) -let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.term) : - Lambda_ast.term = +let rec type_term (ctxt : Name_resolution.context) (local_ctx : Lambda_ast.typ Uid.LocalVarMap.t) + (((t, pos), _) : Lambda_ast.term) : Lambda_ast.term = match t with - | EVar uid -> - let typ = Name_resolution.get_uid_typ ctxt uid in - ((EVar uid, pos), typ) - | ELocalVar _ -> assert false + | EVar (s, uid) -> + (* so here we can ignore the subscope uid because the uid of the variable already corresponds + to the uid of the variable in its original scope *) + let typ = Name_resolution.get_var_typ ctxt uid in + ((EVar (s, uid), pos), typ) + | ELocalVar uid -> + let typ = Uid.LocalVarMap.find uid local_ctx in + ((ELocalVar uid, pos), typ) | EFun (bindings, body) -> - (* Note that given the context formation process, the binder will already be present in the - context (since we are working with uids), however it's added there for the sake of clarity *) - let ctxt_data = + let local_ctx = List.fold_left - (fun data (uid, arg_typ) -> - let uid_data : Name_resolution.uid_data = - { uid_typ = arg_typ; uid_sort = Name_resolution.IdBinder } - in - Uid.UidMap.add uid uid_data data) - ctxt.data bindings + (fun local_ctx (binding, ty) -> Uid.LocalVarMap.add binding ty local_ctx) + local_ctx bindings in - - let body = typing { ctxt with data = ctxt_data } body in + let body = type_term ctxt local_ctx body in let ret_typ = Lambda_ast.get_typ body in let rec build_typ = function | [] -> ret_typ @@ -42,28 +39,61 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t let fun_typ = build_typ bindings in ((EFun (bindings, body), pos), fun_typ) | EApp (f, args) -> - let f = typing ctxt f in + let f = type_term ctxt local_ctx f in let f_typ = Lambda_ast.get_typ f in - let args = List.map (typing ctxt) args in - let args_typ = List.map Lambda_ast.get_typ args in + let args = List.map (type_term ctxt local_ctx) args in + let args_typ = + List.map (fun arg -> (Lambda_ast.get_typ arg, Pos.get_position (fst arg))) args + in let rec check_arrow_typ f_typ args_typ = match (f_typ, args_typ) with | typ, [] -> typ - | TArrow (arg_typ, ret_typ), fst_typ :: typs -> - if arg_typ = fst_typ then check_arrow_typ ret_typ typs else assert false - | _ -> assert false + | Lambda_ast.TArrow (arg_typ, ret_typ), fst_typ :: typs -> + let fst_typ_s = Pos.unmark fst_typ in + if arg_typ = fst_typ_s then check_arrow_typ ret_typ typs + else + Errors.raise_multispanned_error "error when comparing types of function arguments" + [ + ( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)), + Pos.get_position (fst f) ); + ( Some (Printf.sprintf "got type %s" (Format_lambda.format_typ fst_typ_s)), + Pos.get_position fst_typ ); + ] + | _ -> + Errors.raise_multispanned_error "wrong number of arguments for function call" + [ + ( Some (Printf.sprintf "expected type %s" (Format_lambda.format_typ f_typ)), + Pos.get_position (fst f) ); + ( Some + (Printf.sprintf "got type %s" + (String.concat " -> " + (List.map (fun (ty, _) -> Format_lambda.format_typ ty) args_typ))), + Pos.get_position (List.hd args_typ) ); + ] in let ret_typ = check_arrow_typ f_typ args_typ in ((EApp (f, args), pos), ret_typ) | EIfThenElse (t_if, t_then, t_else) -> - let t_if = typing ctxt t_if in + let t_if = type_term ctxt local_ctx t_if in let typ_if = Lambda_ast.get_typ t_if in - let t_then = typing ctxt t_then in + let t_then = type_term ctxt local_ctx t_then in let typ_then = Lambda_ast.get_typ t_then in - let t_else = typing ctxt t_else in + let t_else = type_term ctxt local_ctx t_else in let typ_else = Lambda_ast.get_typ t_else in - if typ_if <> TBool then assert false - else if typ_then <> typ_else then assert false + if typ_if <> TBool then + Errors.raise_spanned_error + (Format.sprintf "expecting type bool, got type %s" (Format_lambda.format_typ typ_if)) + (Pos.get_position (fst t_if)) + else if typ_then <> typ_else then + Errors.raise_multispanned_error + "expecting same types for the true and false branches of the conditional" + [ + ( Some (Format.sprintf "the true branch has type %s" (Format_lambda.format_typ typ_then)), + Pos.get_position (fst t_then) ); + ( Some + (Format.sprintf "the false branch has type %s" (Format_lambda.format_typ typ_else)), + Pos.get_position (fst t_else) ); + ] else ((EIfThenElse (t_if, t_then, t_else), pos), typ_then) | EInt _ | EDec _ -> ((t, pos), TInt) | EBool _ -> ((t, pos), TBool) @@ -72,7 +102,7 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t match op with | Binop binop -> ( match binop with - | And | Or -> TArrow (TBool, TArrow (TBool, TBool)) + | And | Or -> Lambda_ast.TArrow (TBool, TArrow (TBool, TBool)) | Add | Sub | Mult | Div -> TArrow (TInt, TArrow (TInt, TInt)) | Lt | Lte | Gt | Gte | Eq | Neq -> TArrow (TInt, TArrow (TInt, TBool)) ) | Unop Minus -> TArrow (TInt, TInt) @@ -81,17 +111,28 @@ let rec type_term (ctxt : Name_resolution.context) (((t, pos), _) : Lambda_ast.t ((t, pos), typ) | EDefault t -> let defaults = - IntMap.map + List.map (fun (just, cons) -> - let just = typing ctxt just in - if Lambda_ast.get_typ just <> TBool then - let cons = typing ctxt cons in - (just, cons) - else assert false) + let just_t = type_term ctxt local_ctx just in + if Lambda_ast.get_typ just_t <> TBool then + let cons = type_term ctxt local_ctx cons in + (just_t, cons) + else + Errors.raise_spanned_error + (Format.sprintf "expected type of default condition to be bool, got %s" + (Format_lambda.format_typ (Lambda_ast.get_typ just))) + (Pos.get_position (fst just))) t.defaults in - let typ_cons = IntMap.choose defaults |> snd |> snd |> Lambda_ast.get_typ in - IntMap.iter - (fun _ (_, cons) -> if Lambda_ast.get_typ cons <> typ_cons then assert false else ()) + let typ_cons = List.hd defaults |> snd |> snd in + List.iter + (fun (_, cons) -> + if Lambda_ast.get_typ cons <> typ_cons then + Errors.raise_spanned_error + (Format.sprintf "expected default condition to be of type %s, got type %s" + (Format_lambda.format_typ (Lambda_ast.get_typ cons)) + (Format_lambda.format_typ typ_cons)) + (Pos.get_position (fst cons)) + else ()) defaults; ((EDefault { t with defaults }, pos), typ_cons)