2020-11-22 22:56:27 +03:00
|
|
|
(* This file is part of the Catala compiler, a specification language for tax
|
|
|
|
and social benefits computation rules. Copyright (C) 2020 Inria, contributor:
|
|
|
|
Denis Merigoux <denis.merigoux@inria.fr>
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
|
|
|
use this file except in compliance with the License. You may obtain a copy of
|
|
|
|
the License at
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
License for the specific language governing permissions and limitations under
|
|
|
|
the License. *)
|
|
|
|
|
|
|
|
(** Typing for the default calculus. Because of the error terms, we perform type
|
|
|
|
inference using the classical W algorithm with union-find unification. *)
|
|
|
|
|
2021-01-21 23:33:04 +03:00
|
|
|
open Utils
|
2022-09-13 16:20:13 +03:00
|
|
|
module A = Definitions
|
2022-07-28 11:36:36 +03:00
|
|
|
|
|
|
|
(** Identifiers for the type variables carried by [TAny]. Fresh identifiers
    are drawn with [Any.fresh]. The information attached to an identifier is
    [unit]: it renders as "any", and any two such informations are considered
    equal. *)
module Any =
  Utils.Uid.Make
    (struct
      type info = unit

      let to_string _ = "any"
      let format_info ppf () = Format.pp_print_string ppf "any"
      let equal _ _ = true
      let compare _ _ = 0
    end)
    ()
|
|
|
|
|
2022-08-25 18:29:00 +03:00
|
|
|
type unionfind_typ = naked_typ Marked.pos UnionFind.elem
(** We do not reuse {!type: Shared_ast.typ} because we have to include a new
    [TAny] variant. Indeed, error terms can have any type and this has to be
    captured by the type system. Every node is a union-find element holding a
    position-annotated naked type, so that unification can link types in
    place. *)

and naked_typ =
  | TLit of A.typ_lit  (** Literal (base) types. *)
  | TArrow of unionfind_typ * unionfind_typ
      (** Function type: argument type, then result type. *)
  | TTuple of unionfind_typ list  (** Anonymous tuple type. *)
  | TStruct of A.StructName.t  (** Named structure type. *)
  | TEnum of A.EnumName.t  (** Named enumeration type. *)
  | TOption of unionfind_typ  (** Option type over the given type. *)
  | TArray of unionfind_typ  (** Collection whose elements have this type. *)
  | TAny of Any.t
      (** Type variable: compatible with any other type during unification. *)
|
|
|
|
|
2022-08-25 18:29:00 +03:00
|
|
|
(** Converts a union-find type back to an AST type, following union-find links
    at every level. The position of each representative is kept; every type
    variable, whatever its identifier, becomes the plain [A.TAny]. *)
let rec typ_to_ast (ty : unionfind_typ) : A.typ =
  let naked, pos = UnionFind.get (UnionFind.find ty) in
  let ast_naked =
    match naked with
    | TAny _ -> A.TAny
    | TLit lit -> A.TLit lit
    | TStruct name -> A.TStruct name
    | TEnum name -> A.TEnum name
    | TTuple elts -> A.TTuple (List.map typ_to_ast elts)
    | TOption inner -> A.TOption (typ_to_ast inner)
    | TArray elt -> A.TArray (typ_to_ast elt)
    | TArrow (arg, res) -> A.TArrow (typ_to_ast arg, typ_to_ast res)
  in
  ast_naked, pos
|
2022-07-28 11:36:36 +03:00
|
|
|
|
2022-08-25 18:29:00 +03:00
|
|
|
(** Builds a fresh union-find type from an AST type, keeping the position of
    every node. Each occurrence of [A.TAny] gets its own fresh type variable,
    so two [A.TAny] in the input are not linked to each other. *)
let rec ast_to_typ (ty : A.typ) : unionfind_typ =
  let naked =
    match Marked.unmark ty with
    | A.TAny -> TAny (Any.fresh ())
    | A.TLit lit -> TLit lit
    | A.TStruct name -> TStruct name
    | A.TEnum name -> TEnum name
    | A.TTuple elts -> TTuple (List.map ast_to_typ elts)
    | A.TOption inner -> TOption (ast_to_typ inner)
    | A.TArray elt -> TArray (ast_to_typ elt)
    | A.TArrow (arg, res) -> TArrow (ast_to_typ arg, ast_to_typ res)
  in
  Marked.same_mark_as naked ty |> UnionFind.make
|
2020-11-22 22:56:27 +03:00
|
|
|
|
2020-12-14 20:09:38 +03:00
|
|
|
(** {1 Types and unification} *)
|
|
|
|
|
2022-08-25 18:29:00 +03:00
|
|
|
(** Returns [true] exactly when the current representative of [t] is an arrow
    or a collection type. The pretty-printer uses this to decide when a nested
    type must be wrapped in parentheses. *)
let typ_needs_parens (t : unionfind_typ) : bool =
  match Marked.unmark (UnionFind.get (UnionFind.find t)) with
  | TArrow _ | TArray _ -> true
  | TLit _ | TTuple _ | TStruct _ | TEnum _ | TOption _ | TAny _ -> false
|
2020-11-22 22:56:27 +03:00
|
|
|
|
2021-01-14 02:17:24 +03:00
|
|
|
(** [format_typ ctx fmt t] pretty-prints the current representative of [t],
    following union-find links.

    - Arrow argument types and option payload types are parenthesized when
      they are themselves arrows or collections (see [typ_needs_parens]).
    - A collection whose element type is still an unresolved type variable
      prints as a bare [collection].
    - An unresolved type variable prints as [<any>]. *)
let rec format_typ
    (ctx : A.decl_ctx)
    (fmt : Format.formatter)
    (naked_typ : unionfind_typ) : unit =
  let format_typ = format_typ ctx in
  let format_typ_with_parens (fmt : Format.formatter) (t : unionfind_typ) =
    if typ_needs_parens t then Format.fprintf fmt "(%a)" format_typ t
    else format_typ fmt t
  in
  let naked_typ = UnionFind.get (UnionFind.find naked_typ) in
  match Marked.unmark naked_typ with
  | TLit l -> Print.tlit fmt l
  | TTuple ts ->
    (* The box must be closed with "@]": a bare "]" leaves the box open and
       prints a stray closing bracket after the tuple. *)
    Format.fprintf fmt "@[<hov 2>(%a)@]"
      (Format.pp_print_list
         ~pp_sep:(fun fmt () -> Format.fprintf fmt "@ *@ ")
         format_typ)
      ts
  | TStruct s -> A.StructName.format_t fmt s
  | TEnum e -> A.EnumName.format_t fmt e
  | TOption t ->
    Format.fprintf fmt "@[<hov 2>%a@ %s@]" format_typ_with_parens t "eoption"
  | TArrow (t1, t2) ->
    Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_typ_with_parens t1
      format_typ t2
  | TArray t1 -> (
    match Marked.unmark (UnionFind.get (UnionFind.find t1)) with
    | TAny _ -> Format.pp_print_string fmt "collection"
    | _ -> Format.fprintf fmt "@[collection@ %a@]" format_typ t1)
  | TAny _ -> Format.pp_print_string fmt "<any>"
|
2020-11-22 22:56:27 +03:00
|
|
|
|
2022-08-26 12:06:00 +03:00
|
|
|
(** Raised by [unify] when two types are incompatible. Carries the expression
    being typechecked (for error context) and the two types that failed to
    unify, in the order they were passed to [unify]. *)
exception Type_error of A.any_expr * unionfind_typ * unionfind_typ
|
2022-07-11 12:32:23 +03:00
|
|
|
|
2022-07-28 11:36:36 +03:00
|
|
|
(** Annotation attached to each node of a typed expression: its source
    position and its union-find type. *)
type mark = { pos : Pos.t; uf : unionfind_typ }
|
|
|
|
|
2022-09-16 18:33:09 +03:00
|
|
|
(** [unify ctx e t1 t2] makes [t1] and [t2] equal.

    Compound types are first unified component-wise; for arrows, the result
    types are unified before the argument types. A [TAny] on either side is
    compatible with anything. The two union-find classes are then merged. The
    merged class keeps [t2]'s value, and therefore its position annotation,
    unless that value is a [TAny], in which case [t1]'s value is kept.

    @raise Type_error carrying [e] (only used as error context), [t1] and
    [t2] when the types are structurally incompatible. *)
let rec unify
    (ctx : A.decl_ctx)
    (e : ('a, 'm A.mark) A.gexpr) (* used for error context *)
    (t1 : unionfind_typ)
    (t2 : unionfind_typ) : unit =
  (* Cli.debug_format "Unifying %a and %a" (format_typ ctx) t1 (format_typ ctx)
     t2; *)
  let unify_parts = unify ctx e in
  let mismatch () = raise (Type_error (A.AnyExpr e, t1, t2)) in
  let current t = UnionFind.get (UnionFind.find t) in
  let lhs = Marked.unmark (current t1) in
  let rhs = Marked.unmark (current t2) in
  (match lhs, rhs with
  | TLit l1, TLit l2 -> if l1 <> l2 then mismatch ()
  | TArrow (arg1, res1), TArrow (arg2, res2) ->
    unify_parts res1 res2;
    unify_parts arg1 arg2
  | TTuple elts1, TTuple elts2 ->
    if List.compare_lengths elts1 elts2 <> 0 then mismatch ()
    else List.iter2 unify_parts elts1 elts2
  | TStruct s1, TStruct s2 -> if not (A.StructName.equal s1 s2) then mismatch ()
  | TEnum n1, TEnum n2 -> if not (A.EnumName.equal n1 n2) then mismatch ()
  | TOption inner1, TOption inner2 -> unify_parts inner1 inner2
  | TArray elt1, TArray elt2 -> unify_parts elt1 elt2
  | TAny _, _ | _, TAny _ -> ()
  | ( ( TLit _ | TArrow _ | TTuple _ | TStruct _ | TEnum _ | TOption _
      | TArray _ ),
      _ ) ->
    mismatch ());
  (* Prefer a concrete type over a type variable as the class representative. *)
  let keep_informative v1 v2 =
    match Marked.unmark v2 with TAny _ -> v1 | _ -> v2
  in
  ignore (UnionFind.merge keep_informative t1 t2)
|
2020-11-22 22:56:27 +03:00
|
|
|
|
2022-07-11 12:32:23 +03:00
|
|
|
(** Reports a failed unification of [t1] and [t2] through
    [Errors.raise_multispanned_error]. The error carries three spans: the
    expression [e] being typechecked, and the positions annotating the current
    representatives of [t1] and [t2]. Both types are rendered on a single line
    and highlighted in yellow. *)
let handle_type_error ctx e t1 t2 =
  (* TODO: if we get weird error messages, then it means that we should use the
     persistent version of the union-find data structure. *)
  let expr_pos =
    match e with
    | A.AnyExpr e -> (
      match Marked.get_mark e with Untyped { pos } | Typed { pos; _ } -> pos)
  in
  let repr_pos t = Marked.get_mark (UnionFind.get (UnionFind.find t)) in
  let t1_pos = repr_pos t1 in
  let t2_pos = repr_pos t2 in
  let typ_to_flat_string typ =
    let buf = Buffer.create 59 in
    let ppf = Format.formatter_of_buffer buf in
    (* set infinite width to disable line cuts *)
    Format.pp_set_margin ppf max_int;
    format_typ ctx ppf typ;
    Format.pp_print_flush ppf ();
    Buffer.contents buf
  in
  let pp_highlighted typ fmt () =
    Cli.format_with_style [ANSITerminal.yellow] fmt (typ_to_flat_string typ)
  in
  let pp_arrow = Cli.format_with_style [ANSITerminal.blue; ANSITerminal.Bold] in
  Errors.raise_multispanned_error
    [
      Some "Error coming from typechecking the following expression:", expr_pos;
      ( Some
          (Format.asprintf "Type %a coming from expression:" (pp_highlighted t1)
             ()),
        t1_pos );
      ( Some
          (Format.asprintf "Type %a coming from expression:" (pp_highlighted t2)
             ()),
        t2_pos );
    ]
    "Error during typechecking, incompatible types:\n%a %a\n%a %a" pp_arrow
    "-->" (pp_highlighted t1) () pp_arrow "-->" (pp_highlighted t2) ()
|
|
|
|
|
2022-09-13 16:20:13 +03:00
|
|
|
(** The naked type of a literal. The empty-error literal can take any type,
    so it gets a fresh type variable; every other literal has its fixed base
    type. *)
let lit_type (type a) (lit : a A.glit) : naked_typ =
  match lit with
  | LEmptyError -> TAny (Any.fresh ())
  | LUnit -> TLit TUnit
  | LBool _ -> TLit TBool
  | LInt _ -> TLit TInt
  | LRat _ -> TLit TRat
  | LMoney _ -> TLit TMoney
  | LDate _ -> TLit TDate
  | LDuration _ -> TLit TDuration
|
|
|
|
|
2020-12-14 20:09:38 +03:00
|
|
|
(** Operators have a single type, instead of being polymorphic with constraints.
    This allows us to have a simpler type system, while we argue the syntactic
    burden of operator annotations helps the programmer visualize the type flow
    in the code.

    Every node of the returned type carries the position of [op]. Operators
    that exist syntactically but have no meaning (multiplying or dividing
    dates, dividing durations, negating dates) are rejected with a spanned
    error. *)
let op_type (op : A.operator Marked.pos) : unionfind_typ =
  let pos = Marked.get_mark op in
  let mk (naked : naked_typ) : unionfind_typ = UnionFind.make (naked, pos) in
  let t_bool = mk (TLit TBool) in
  let t_int = mk (TLit TInt) in
  let t_rat = mk (TLit TRat) in
  let t_money = mk (TLit TMoney) in
  let t_dur = mk (TLit TDuration) in
  let t_date = mk (TLit TDate) in
  (* Two type variables, each shared by all of its occurrences within one
     operator signature (e.g. both operands of [Eq]). *)
  let t_a = mk (TAny (Any.fresh ())) in
  let t_arr_a = mk (TArray t_a) in
  let t_b = mk (TAny (Any.fresh ())) in
  let t_arr_b = mk (TArray t_b) in
  let fn arg res = mk (TArrow (arg, res)) in
  let fn2 arg1 arg2 res = fn arg1 (fn arg2 res) in
  match Marked.unmark op with
  (* Higher-order collection operators *)
  | A.Ternop A.Fold -> fn (fn2 t_b t_a t_b) (fn2 t_b t_arr_a t_b)
  | A.Binop A.Map -> fn (fn t_a t_b) (fn t_arr_a t_arr_b)
  | A.Binop A.Filter -> fn (fn t_a t_bool) (fn t_arr_a t_arr_a)
  | A.Binop A.Concat -> fn2 t_arr_a t_arr_a t_arr_a
  | A.Unop A.Length -> fn t_arr_a t_int
  (* Booleans and equality *)
  | A.Binop (A.And | A.Or | A.Xor) -> fn2 t_bool t_bool t_bool
  | A.Binop (A.Eq | A.Neq) -> fn2 t_a t_a t_bool
  | A.Unop (A.Not | A.Log (A.PosRecordIfTrueBool, _)) -> fn t_bool t_bool
  | A.Unop (A.Log _) -> fn t_a t_a
  (* Arithmetic *)
  | A.Binop (A.Add KInt | A.Sub KInt | A.Mult KInt | A.Div KInt) ->
    fn2 t_int t_int t_int
  | A.Binop (A.Add KRat | A.Sub KRat | A.Mult KRat | A.Div KRat) ->
    fn2 t_rat t_rat t_rat
  | A.Binop (A.Add KMoney | A.Sub KMoney) -> fn2 t_money t_money t_money
  | A.Binop (A.Mult KMoney) -> fn2 t_money t_rat t_money
  | A.Binop (A.Div KMoney) -> fn2 t_money t_money t_rat
  | A.Binop (A.Add KDuration | A.Sub KDuration) -> fn2 t_dur t_dur t_dur
  | A.Binop (A.Mult KDuration) -> fn2 t_dur t_int t_dur
  | A.Binop (A.Add KDate) -> fn2 t_date t_dur t_date
  | A.Binop (A.Sub KDate) -> fn2 t_date t_date t_dur
  | A.Unop (A.Minus KInt) -> fn t_int t_int
  | A.Unop (A.Minus KRat) -> fn t_rat t_rat
  | A.Unop (A.Minus KMoney) -> fn t_money t_money
  | A.Unop (A.Minus KDuration) -> fn t_dur t_dur
  (* Comparisons *)
  | A.Binop (A.Lt KInt | A.Lte KInt | A.Gt KInt | A.Gte KInt) ->
    fn2 t_int t_int t_bool
  | A.Binop (A.Lt KRat | A.Lte KRat | A.Gt KRat | A.Gte KRat) ->
    fn2 t_rat t_rat t_bool
  | A.Binop (A.Lt KMoney | A.Lte KMoney | A.Gt KMoney | A.Gte KMoney) ->
    fn2 t_money t_money t_bool
  | A.Binop (A.Lt KDate | A.Lte KDate | A.Gt KDate | A.Gte KDate) ->
    fn2 t_date t_date t_bool
  | A.Binop (A.Lt KDuration | A.Lte KDuration | A.Gt KDuration | A.Gte KDuration)
    ->
    fn2 t_dur t_dur t_bool
  (* Dates *)
  | A.Unop (A.GetDay | A.GetMonth | A.GetYear) -> fn t_date t_int
  | A.Unop (A.FirstDayOfMonth | A.LastDayOfMonth) -> fn t_date t_date
  (* Rounding and conversions *)
  | A.Unop A.RoundMoney -> fn t_money t_money
  | A.Unop A.RoundDecimal -> fn t_rat t_rat
  | A.Unop A.IntToRat -> fn t_int t_rat
  | A.Unop A.MoneyToRat -> fn t_money t_rat
  | A.Unop A.RatToMoney -> fn t_rat t_money
  | Binop (Mult KDate) | Binop (Div (KDate | KDuration)) | Unop (Minus KDate) ->
    Errors.raise_spanned_error pos "This operator is not available!"
|
2020-11-24 13:27:23 +03:00
|
|
|
|
2020-12-14 20:09:38 +03:00
|
|
|
(** {1 Double-directed typing} *)
|
2020-11-22 22:56:27 +03:00
|
|
|
|
2022-09-26 17:32:02 +03:00
|
|
|
(** Typing environment: types of locally bound variables, of the variables of
    the current scope, and of the variables of each subscope. *)
module Env = struct
  type 'e t = {
    vars : ('e, unionfind_typ) Var.Map.t;
    scope_vars : A.typ A.ScopeVarMap.t;
    scopes : A.typ A.ScopeVarMap.t A.ScopeMap.t;
  }

  (** The environment with no bindings. *)
  let empty =
    {
      scopes = A.ScopeMap.empty;
      scope_vars = A.ScopeVarMap.empty;
      vars = Var.Map.empty;
    }

  (** Type of a local variable, if bound. *)
  let get { vars; _ } v = Var.Map.find_opt v vars

  (** Declared type of a variable of the current scope, if any. *)
  let get_scope_var { scope_vars; _ } sv = A.ScopeVarMap.find_opt sv scope_vars

  (** Declared type of variable [var] of subscope [scope]; [None] if either
      the scope or the variable is unknown. *)
  let get_subscope_var t scope var =
    match A.ScopeMap.find_opt scope t.scopes with
    | None -> None
    | Some vars_of_scope -> A.ScopeVarMap.find_opt var vars_of_scope

  (** Binds local variable [v] to the union-find type [tau]. *)
  let add v tau env = { env with vars = Var.Map.add v tau env.vars }

  (** Binds local variable [v] to a fresh union-find copy of AST type [typ]. *)
  let add_var v typ env =
    let tau = ast_to_typ typ in
    add v tau env

  (** Declares the type of a variable of the current scope. *)
  let add_scope_var v typ env =
    { env with scope_vars = A.ScopeVarMap.add v typ env.scope_vars }

  (** Declares the variable types of scope [scope_name]. *)
  let add_scope scope_name vmap env =
    { env with scopes = A.ScopeMap.add scope_name vmap env.scopes }
end
|
|
|
|
|
2022-09-13 16:20:13 +03:00
|
|
|
(** [add_pos e ty] annotates [ty] with the source position of expression [e]. *)
let add_pos e ty =
  let pos = Expr.pos e in
  Marked.mark pos ty

(** The union-find type stored in the mark of a typed expression. *)
let ty (_, m) = m.uf

(** Binding operators on [Bindlib.box]: [let+ x = b in body] applies
    [fun x -> body] to the content of box [b] (via [Bindlib.box_apply]), and
    [and+] pairs two boxes so several boxed values can be bound at once. *)
let ( let+ ) boxed f = Bindlib.box_apply f boxed

let ( and+ ) = Bindlib.box_pair
|
|
|
|
|
2022-07-19 22:41:55 +03:00
|
|
|
(* Maps a box-producing function over a list and gathers the results into a
   single boxed list, preserving element order. The traversal is done with
   [List.fold_right]. *)
let bmap (f : 'a -> 'b Bindlib.box) (es : 'a list) : 'b list Bindlib.box =
  let cons_boxed elt boxed_tail =
    let+ elt' = f elt and+ tail = boxed_tail in
    elt' :: tail
  in
  List.fold_right cons_boxed es (Bindlib.box [])
|
2020-12-14 20:09:38 +03:00
|
|
|
|
2022-07-19 22:41:55 +03:00
|
|
|
(* Two-list variant of [bmap]: applies [f] pairwise to [es] and [xs] and
   gathers the boxed results into a single boxed list. Built on
   [List.fold_right2], which raises [Invalid_argument] when the lists have
   different lengths. *)
let bmap2 (f : 'a -> 'b -> 'c Bindlib.box) (es : 'a list) (xs : 'b list) :
    'c list Bindlib.box =
  let cons_boxed elt extra boxed_tail =
    let+ elt' = f elt extra and+ tail = boxed_tail in
    elt' :: tail
  in
  List.fold_right2 cons_boxed es xs (Bindlib.box [])
|
|
|
|
|
|
|
|
(* Reads the union-find type out of a boxed typed expression; this unboxes
   (and therefore evaluates) the box. *)
let box_ty boxed = boxed |> Bindlib.box_apply ty |> Bindlib.unbox
|
2020-11-22 22:56:27 +03:00
|
|
|
|
2020-12-14 20:09:38 +03:00
|
|
|
(** Infers the most permissive type from an expression.

    Returns the expression rebuilt (in a box) with every node annotated by its
    source position and a union-find type. Type constraints are recorded by
    unification on the shared union-find structure. Some unifications happen
    inside [let+] bodies, i.e. inside [Bindlib.box_apply], and may therefore be
    deferred until the box is evaluated.

    Raises a spanned error (through [Errors]) for unknown variables or
    locations, for functions whose arity disagrees with their type
    annotations, and for tuple or sum indices that are out of range.
    Unification failures raise [Type_error]. *)
let rec typecheck_expr_bottom_up :
    type a.
    A.decl_ctx ->
    (a, 'm A.mark) A.gexpr Env.t ->
    (a, 'm A.mark) A.gexpr ->
    (a, mark) A.gexpr A.box =
 fun ctx env e ->
  (* Cli.debug_format "Looking for type of %a" (Expr.format ~debug:true ctx)
     e; *)
  let pos_e = Expr.pos e in
  let mark e uf = Marked.mark { uf; pos = pos_e } e in
  let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
  let mark_with_uf e1 ?pos ty = mark e1 (unionfind_make ?pos ty) in
  match Marked.unmark e with
  | A.ELocation loc as e1 -> (
    let ty =
      match loc with
      | DesugaredScopeVar (v, _) | ScopelangScopeVar v ->
        Env.get_scope_var env (Marked.unmark v)
      | SubScopeVar (scope_name, _, v) ->
        Env.get_subscope_var env scope_name (Marked.unmark v)
    in
    match ty with
    | Some ty -> Bindlib.box (mark e1 (ast_to_typ ty))
    | None ->
      Errors.raise_spanned_error pos_e "Reference to %a not found"
        (Expr.format ctx) e)
  | A.EStruct (s_name, fmap) ->
    let+ fmap' =
      (* This assumes that the fields in fmap and the struct type are already
         ensured to be the same *)
      List.fold_left
        (fun fmap' (f_name, f_ty) ->
          let f_e = A.StructFieldMap.find f_name fmap in
          let+ fmap'
          and+ f_e' = typecheck_expr_top_down ctx env (ast_to_typ f_ty) f_e in
          A.StructFieldMap.add f_name f_e' fmap')
        (Bindlib.box A.StructFieldMap.empty)
        (A.StructMap.find s_name ctx.A.ctx_structs)
    in
    mark_with_uf (A.EStruct (s_name, fmap')) (TStruct s_name)
  | A.EStructAccess (e_struct, f_name, s_name) ->
    let f_ty =
      ast_to_typ (List.assoc f_name (A.StructMap.find s_name ctx.A.ctx_structs))
    in
    let+ e_struct' =
      typecheck_expr_top_down ctx env (unionfind_make (TStruct s_name)) e_struct
    in
    mark (A.EStructAccess (e_struct', f_name, s_name)) f_ty
  | A.EEnumInj (e_enum, c_name, e_name) ->
    let c_ty =
      ast_to_typ (List.assoc c_name (A.EnumMap.find e_name ctx.A.ctx_enums))
    in
    let+ e_enum' =
      typecheck_expr_top_down ctx env (unionfind_make (TEnum e_name)) e_enum
    in
    mark (A.EEnumInj (e_enum', c_name, e_name)) c_ty
  | A.EMatchS (e1, e_name, cases) ->
    let cases_ty = A.EnumMap.find e_name ctx.A.ctx_enums in
    let t_ret = unionfind_make ~pos:e1 (TAny (Any.fresh ())) in
    let+ e1' =
      typecheck_expr_top_down ctx env (unionfind_make (TEnum e_name)) e1
    and+ cases' =
      A.EnumConstructorMap.fold
        (fun c_name e cases' ->
          let c_ty = List.assoc c_name cases_ty in
          let e_ty = unionfind_make ~pos:e (TArrow (ast_to_typ c_ty, t_ret)) in
          let+ cases' and+ e' = typecheck_expr_top_down ctx env e_ty e in
          A.EnumConstructorMap.add c_name e' cases')
        cases
        (Bindlib.box A.EnumConstructorMap.empty)
    in
    mark (A.EMatchS (e1', e_name, cases')) t_ret
  | A.ERaise ex ->
    Bindlib.box (mark_with_uf (A.ERaise ex) (TAny (Any.fresh ())))
  | A.ECatch (e1, ex, e2) ->
    let+ e1' = typecheck_expr_bottom_up ctx env e1
    and+ e2' = typecheck_expr_bottom_up ctx env e2 in
    let e_ty = ty e1' in
    unify ctx e e_ty (ty e2');
    mark (A.ECatch (e1', ex, e2')) e_ty
  | A.EVar v -> begin
    match Env.get env v with
    | Some t ->
      let+ v' = Bindlib.box_var (Var.translate v) in
      mark v' t
    | None ->
      Errors.raise_spanned_error (Expr.pos e)
        "Variable %s not found in the current context." (Bindlib.name_of v)
  end
  | A.ELit lit as e1 -> Bindlib.box @@ mark_with_uf e1 (lit_type lit)
  | A.ETuple (es, None) ->
    let+ es = bmap (typecheck_expr_bottom_up ctx env) es in
    mark_with_uf (A.ETuple (es, None)) (TTuple (List.map ty es))
  | A.ETuple (es, Some s_name) ->
    let tys =
      List.map
        (fun (_, ty) -> ast_to_typ ty)
        (A.StructMap.find s_name ctx.A.ctx_structs)
    in
    let+ es = bmap2 (typecheck_expr_top_down ctx env) tys es in
    mark_with_uf (A.ETuple (es, Some s_name)) (TStruct s_name)
  | A.ETupleAccess (e1, n, s, typs) -> begin
    let utyps = List.map ast_to_typ typs in
    let tuple_ty = match s with None -> TTuple utyps | Some s -> TStruct s in
    let+ e1 = typecheck_expr_top_down ctx env (unionfind_make tuple_ty) e1 in
    match List.nth_opt utyps n with
    | Some t' -> mark (A.ETupleAccess (e1, n, s, typs)) t'
    | None ->
      (* Index [n] is 0-based, so accessing it requires [n + 1] elements. *)
      Errors.raise_spanned_error (Marked.get_mark e1).pos
        "Expression should have a tuple type with at least %d elements but \
         only has %d"
        (n + 1) (List.length typs)
  end
  | A.EInj (e1, n, e_name, ts) ->
    let ts' = List.map ast_to_typ ts in
    let ts_n =
      match List.nth_opt ts' n with
      | Some ts_n -> ts_n
      | None ->
        (* Case index [n] is 0-based, so it requires [n + 1] cases. *)
        Errors.raise_spanned_error (Expr.pos e)
          "Expression should have a sum type with at least %d cases but only \
           has %d"
          (n + 1) (List.length ts')
    in
    let+ e1' = typecheck_expr_top_down ctx env ts_n e1 in
    mark_with_uf (A.EInj (e1', n, e_name, ts)) (TEnum e_name)
  | A.EMatch (e1, es, e_name) ->
    let enum_cases =
      List.map (fun e' -> unionfind_make ~pos:e' (TAny (Any.fresh ()))) es
    in
    let t_e1 = UnionFind.make (add_pos e1 (TEnum e_name)) in
    let t_ret = unionfind_make ~pos:e (TAny (Any.fresh ())) in
    let+ e1' = typecheck_expr_top_down ctx env t_e1 e1
    and+ es' =
      bmap2
        (fun es' enum_t ->
          typecheck_expr_top_down ctx env
            (unionfind_make ~pos:es' (TArrow (enum_t, t_ret)))
            es')
        es enum_cases
    in
    mark (A.EMatch (e1', es', e_name)) t_ret
  | A.EAbs (binder, taus) ->
    if Bindlib.mbinder_arity binder <> List.length taus then
      Errors.raise_spanned_error (Expr.pos e)
        "function has %d variables but was supplied %d types"
        (Bindlib.mbinder_arity binder)
        (List.length taus)
    else
      let xs, body = Bindlib.unmbind binder in
      let xs' = Array.map Var.translate xs in
      let xstaus = List.mapi (fun i tau -> xs.(i), ast_to_typ tau) taus in
      let env =
        List.fold_left (fun env (x, tau) -> Env.add x tau env) env xstaus
      in
      let body' = typecheck_expr_bottom_up ctx env body in
      let t_func =
        List.fold_right
          (fun (_, t_arg) acc -> unionfind_make (TArrow (t_arg, acc)))
          xstaus (box_ty body')
      in
      let+ binder' = Bindlib.bind_mvar xs' body' in
      mark (A.EAbs (binder', taus)) t_func
  | A.EApp (e1, args) ->
    let args' = bmap (typecheck_expr_bottom_up ctx env) args in
    let t_ret = unionfind_make (TAny (Any.fresh ())) in
    let t_func =
      List.fold_right
        (fun ty_arg acc -> unionfind_make (TArrow (ty_arg, acc)))
        (Bindlib.unbox (Bindlib.box_apply (List.map ty) args'))
        t_ret
    in
    let+ e1' = typecheck_expr_bottom_up ctx env e1 and+ args' in
    unify ctx e (ty e1') t_func;
    mark (A.EApp (e1', args')) t_ret
  | A.EOp op as e1 -> Bindlib.box @@ mark e1 (op_type (Marked.mark pos_e op))
  | A.EDefault (excepts, just, cons) ->
    let just' =
      typecheck_expr_top_down ctx env
        (unionfind_make ~pos:just (TLit TBool))
        just
    in
    let cons' = typecheck_expr_bottom_up ctx env cons in
    let tau = box_ty cons' in
    let+ just'
    and+ cons'
    and+ excepts' =
      bmap (fun except -> typecheck_expr_top_down ctx env tau except) excepts
    in
    mark (A.EDefault (excepts', just', cons')) tau
  | A.EIfThenElse (cond, et, ef) ->
    let cond' =
      typecheck_expr_top_down ctx env
        (unionfind_make ~pos:cond (TLit TBool))
        cond
    in
    let et' = typecheck_expr_bottom_up ctx env et in
    let tau = box_ty et' in
    let+ cond' and+ et' and+ ef' = typecheck_expr_top_down ctx env tau ef in
    mark (A.EIfThenElse (cond', et', ef')) tau
  | A.EAssert e1 ->
    let+ e1' =
      typecheck_expr_top_down ctx env (unionfind_make ~pos:e1 (TLit TBool)) e1
    in
    mark_with_uf (A.EAssert e1') ~pos:e1 (TLit TUnit)
  | A.ErrorOnEmpty e1 ->
    let+ e1' = typecheck_expr_bottom_up ctx env e1 in
    mark (A.ErrorOnEmpty e1') (ty e1')
  | A.EArray es ->
    let cell_type = unionfind_make (TAny (Any.fresh ())) in
    let+ es' =
      bmap
        (fun e1 ->
          let e1' = typecheck_expr_bottom_up ctx env e1 in
          unify ctx e1 cell_type (box_ty e1');
          e1')
        es
    in
    mark_with_uf (A.EArray es') (TArray cell_type)
|
2022-05-04 18:40:55 +03:00
|
|
|
|
2021-01-13 14:04:14 +03:00
|
|
|
(** Checks whether the expression can be typed with the provided type.

    [typecheck_expr_top_down ctx env tau e] unifies the expected type [tau] with
    the type imposed by the head construct of [e], propagates the relevant
    expected types down to the sub-expressions, and rebuilds [e] with every node
    marked by [{ uf; pos }] (its union-find type and position).

    Unknown variables or locations, functions whose arity differs from their
    type annotations, and out-of-range tuple/sum-type indices are reported with
    [Errors.raise_spanned_error]; unification failures come from [unify]. *)
and typecheck_expr_top_down :
      type a.
      A.decl_ctx ->
      (a, 'm A.mark) A.gexpr Env.t ->
      unionfind_typ ->
      (a, 'm A.mark) A.gexpr ->
      (a, mark) A.gexpr Bindlib.box =
 fun ctx env tau e ->
  (* Cli.debug_format "Propagating type %a for naked_expr %a" (format_typ ctx)
     tau (Expr.format ctx) e; *)
  let pos_e = Expr.pos e in
  let mark e = Marked.mark { uf = tau; pos = pos_e } e in
  (* [unify_and_mark tau' f_e] first unifies the type [tau'] dictated by the
     current construct with the expected type [tau], and only then types the
     sub-terms (through the thunk [f_e]) and marks the rebuilt node. *)
  let unify_and_mark tau' f_e =
    unify ctx e tau' tau;
    Bindlib.box_apply (Marked.mark { uf = tau; pos = pos_e }) (f_e ())
  in
  let unionfind_make ?(pos = e) t = UnionFind.make (add_pos pos t) in
  match Marked.unmark e with
  | A.ELocation loc as e1 -> (
    let ty =
      match loc with
      | DesugaredScopeVar (v, _) | ScopelangScopeVar v ->
        Env.get_scope_var env (Marked.unmark v)
      | SubScopeVar (scope, _, v) ->
        Env.get_subscope_var env scope (Marked.unmark v)
    in
    match ty with
    | Some ty -> unify_and_mark (ast_to_typ ty) (fun () -> Bindlib.box e1)
    | None ->
      Errors.raise_spanned_error pos_e "Reference to %a not found"
        (Expr.format ctx) e)
  | A.EStruct (s_name, fmap) ->
    unify_and_mark (unionfind_make (TStruct s_name))
    @@ fun () ->
    let+ fmap' =
      (* This assumes that the fields in fmap and the struct type are already
         ensured to be the same *)
      List.fold_left
        (fun fmap' (f_name, f_ty) ->
          let f_e = A.StructFieldMap.find f_name fmap in
          let+ fmap'
          and+ f_e' = typecheck_expr_top_down ctx env (ast_to_typ f_ty) f_e in
          A.StructFieldMap.add f_name f_e' fmap')
        (Bindlib.box A.StructFieldMap.empty)
        (A.StructMap.find s_name ctx.A.ctx_structs)
    in
    A.EStruct (s_name, fmap')
  | A.EStructAccess (e_struct, f_name, s_name) ->
    unify_and_mark
      (ast_to_typ
         (List.assoc f_name (A.StructMap.find s_name ctx.A.ctx_structs)))
    @@ fun () ->
    let+ e_struct' =
      typecheck_expr_top_down ctx env (unionfind_make (TStruct s_name)) e_struct
    in
    A.EStructAccess (e_struct', f_name, s_name)
  | A.EEnumInj (e_enum, c_name, e_name) ->
    (* The injection itself has the enum type, while its payload [e_enum] must
       have the type declared for constructor [c_name] (same discipline as the
       [EInj] case below). *)
    let c_ty = List.assoc c_name (A.EnumMap.find e_name ctx.A.ctx_enums) in
    unify_and_mark (unionfind_make (TEnum e_name))
    @@ fun () ->
    let+ e_enum' = typecheck_expr_top_down ctx env (ast_to_typ c_ty) e_enum in
    A.EEnumInj (e_enum', c_name, e_name)
  | A.EMatchS (e1, e_name, cases) ->
    let cases_ty = A.EnumMap.find e_name ctx.A.ctx_enums in
    let t_ret = unionfind_make ~pos:e1 (TAny (Any.fresh ())) in
    unify_and_mark t_ret
    @@ fun () ->
    let+ e1' =
      typecheck_expr_top_down ctx env (unionfind_make (TEnum e_name)) e1
    and+ cases' =
      (* each case is a function from the constructor's payload type to the
         common result type [t_ret] *)
      A.EnumConstructorMap.fold
        (fun c_name e cases' ->
          let c_ty = List.assoc c_name cases_ty in
          let e_ty = unionfind_make ~pos:e (TArrow (ast_to_typ c_ty, t_ret)) in
          let+ cases' and+ e' = typecheck_expr_top_down ctx env e_ty e in
          A.EnumConstructorMap.add c_name e' cases')
        cases
        (Bindlib.box A.EnumConstructorMap.empty)
    in
    A.EMatchS (e1', e_name, cases')
  | A.ERaise _ as e1 -> Bindlib.box (mark e1)
  | A.ECatch (e1, ex, e2) ->
    let+ e1' = typecheck_expr_top_down ctx env tau e1
    and+ e2' = typecheck_expr_top_down ctx env tau e2 in
    mark (A.ECatch (e1', ex, e2'))
  | A.EVar v ->
    let tau' =
      match Env.get env v with
      | Some t -> t
      | None ->
        Errors.raise_spanned_error pos_e
          "Variable %s not found in the current context" (Bindlib.name_of v)
    in
    unify_and_mark tau' @@ fun () -> Bindlib.box_var (Var.translate v)
  | A.ELit lit as e1 ->
    unify_and_mark (unionfind_make (lit_type lit))
    @@ fun () -> Bindlib.box @@ e1
  | A.ETuple (es, None) ->
    let tys = List.map (fun _ -> unionfind_make (TAny (Any.fresh ()))) es in
    unify_and_mark (unionfind_make (TTuple tys))
    @@ fun () ->
    let+ es' = bmap2 (typecheck_expr_top_down ctx env) tys es in
    A.ETuple (es', None)
  | A.ETuple (es, Some s_name) ->
    let tys =
      List.map
        (fun (_, ty) -> ast_to_typ ty)
        (A.StructMap.find s_name ctx.A.ctx_structs)
    in
    unify_and_mark (unionfind_make (TStruct s_name))
    @@ fun () ->
    let+ es' = bmap2 (typecheck_expr_top_down ctx env) tys es in
    A.ETuple (es', Some s_name)
  | A.ETupleAccess (e1, n, s, typs) ->
    let typs' = List.map ast_to_typ typs in
    let tuple_ty = match s with None -> TTuple typs' | Some s -> TStruct s in
    let t1n =
      (* [List.nth] signals an out-of-range index with [Failure] (list too
         short) or [Invalid_argument] (negative index), never [Not_found]. The
         index is 0-based, so accessing [n] needs at least [n + 1] elements. *)
      try List.nth typs' n
      with Failure _ | Invalid_argument _ ->
        Errors.raise_spanned_error (Expr.pos e1)
          "Expression should have a tuple type with at least %d elements but \
           only has %d"
          (n + 1) (List.length typs)
    in
    unify_and_mark t1n
    @@ fun () ->
    let+ e1' = typecheck_expr_top_down ctx env (unionfind_make tuple_ty) e1 in
    A.ETupleAccess (e1', n, s, typs)
  | A.EInj (e1, n, e_name, ts) ->
    let ts' = List.map ast_to_typ ts in
    let ts_n =
      (* see the [ETupleAccess] case for why [Failure]/[Invalid_argument] *)
      try List.nth ts' n
      with Failure _ | Invalid_argument _ ->
        Errors.raise_spanned_error (Expr.pos e)
          "Expression should have a sum type with at least %d cases but only \
           has %d"
          (n + 1) (List.length ts)
    in
    unify_and_mark (unionfind_make (TEnum e_name))
    @@ fun () ->
    let+ e1' = typecheck_expr_top_down ctx env ts_n e1 in
    A.EInj (e1', n, e_name, ts)
  | A.EMatch (e1, es, e_name) ->
    let+ es' =
      bmap2
        (fun es' (_, c_ty) ->
          typecheck_expr_top_down ctx env
            (unionfind_make ~pos:es' (TArrow (ast_to_typ c_ty, tau)))
            es')
        es
        (A.EnumMap.find e_name ctx.ctx_enums)
    and+ e1' =
      typecheck_expr_top_down ctx env (unionfind_make ~pos:e1 (TEnum e_name)) e1
    in
    mark (A.EMatch (e1', es', e_name))
  | A.EAbs (binder, t_args) ->
    if Bindlib.mbinder_arity binder <> List.length t_args then
      Errors.raise_spanned_error (Expr.pos e)
        "function has %d variables but was supplied %d types"
        (Bindlib.mbinder_arity binder)
        (List.length t_args)
    else
      let tau_args = List.map ast_to_typ t_args in
      let t_ret = unionfind_make (TAny (Any.fresh ())) in
      let t_func =
        List.fold_right
          (fun t_arg acc -> unionfind_make (TArrow (t_arg, acc)))
          tau_args t_ret
      in
      unify_and_mark t_func
      @@ fun () ->
      let xs, body = Bindlib.unmbind binder in
      let xs' = Array.map Var.translate xs in
      let env =
        List.fold_left2
          (fun env x tau_arg -> Env.add x tau_arg env)
          env (Array.to_list xs) tau_args
      in
      let body' = typecheck_expr_top_down ctx env t_ret body in
      let+ binder' = Bindlib.bind_mvar xs' body' in
      A.EAbs (binder', t_args)
  | A.EApp (e1, args) ->
    let t_args =
      List.map (fun _ -> unionfind_make (TAny (Any.fresh ()))) args
    in
    let t_func =
      List.fold_right
        (fun t_arg acc -> unionfind_make (TArrow (t_arg, acc)))
        t_args tau
    in
    let+ e1' = typecheck_expr_top_down ctx env t_func e1
    and+ args' = bmap2 (typecheck_expr_top_down ctx env) t_args args in
    mark (A.EApp (e1', args'))
  | A.EOp op as e1 ->
    unify_and_mark (op_type (add_pos e op)) @@ fun () -> Bindlib.box e1
  | A.EDefault (excepts, just, cons) ->
    let+ cons' = typecheck_expr_top_down ctx env tau cons
    and+ just' =
      typecheck_expr_top_down ctx env
        (unionfind_make ~pos:just (TLit TBool))
        just
    and+ excepts' = bmap (typecheck_expr_top_down ctx env tau) excepts in
    mark (A.EDefault (excepts', just', cons'))
  | A.EIfThenElse (cond, et, ef) ->
    let+ et' = typecheck_expr_top_down ctx env tau et
    and+ ef' = typecheck_expr_top_down ctx env tau ef
    and+ cond' =
      typecheck_expr_top_down ctx env
        (unionfind_make ~pos:cond (TLit TBool))
        cond
    in
    mark (A.EIfThenElse (cond', et', ef'))
  | A.EAssert e1 ->
    unify_and_mark (unionfind_make ~pos:e1 (TLit TUnit))
    @@ fun () ->
    let+ e1' =
      typecheck_expr_top_down ctx env (unionfind_make ~pos:e1 (TLit TBool)) e1
    in
    A.EAssert e1'
  | A.ErrorOnEmpty e1 ->
    let+ e1' = typecheck_expr_top_down ctx env tau e1 in
    mark (A.ErrorOnEmpty e1')
  | A.EArray es ->
    let cell_type = unionfind_make (TAny (Any.fresh ())) in
    unify_and_mark (unionfind_make (TArray cell_type))
    @@ fun () ->
    let+ es' = bmap (typecheck_expr_top_down ctx env cell_type) es in
    A.EArray es'
|
2022-07-11 12:32:23 +03:00
|
|
|
|
|
|
|
(** [wrap ctx f e] runs the typing function [f] on [e] and unboxes the result.
    A [Type_error] raised meanwhile is handed to [handle_type_error]; whatever
    exception that handler raises is re-raised with the backtrace of the
    original [Type_error]. *)
let wrap ctx f e =
  (* The unboxing must happen inside the guarded expression: the typing work
     may still be pending inside Bindlib closures and would otherwise run
     outside the handler. *)
  match Bindlib.unbox (f e) with
  | result -> result
  | exception Type_error (err_expr, ty1, ty2) -> (
    let backtrace = Printexc.get_raw_backtrace () in
    try handle_type_error ctx err_expr ty1 ty2
    with exn -> Printexc.raise_with_backtrace exn backtrace)
|
2020-11-23 12:44:06 +03:00
|
|
|
|
2020-12-14 20:09:38 +03:00
|
|
|
(** {1 API} *)
|
|
|
|
|
2022-07-28 11:36:36 +03:00
|
|
|
(** Converts an inference-time mark (union-find type + position) into a final
    [A.Typed] mark by resolving the union-find type with [typ_to_ast]. *)
let get_ty_mark m =
  let { uf; pos } = m in
  A.Typed { pos; ty = typ_to_ast uf }
|
2022-07-19 16:19:06 +03:00
|
|
|
|
2020-12-14 20:09:38 +03:00
|
|
|
(** Infers the type of an expression.

    When [typ] is given, [e] is checked against it top-down; otherwise its type
    is inferred bottom-up. [env] (empty by default) supplies the types of free
    variables. The result is re-marked with [A.Typed] marks. *)
let expr
    (type a)
    (ctx : A.decl_ctx)
    ?(env = Env.empty)
    ?(typ : A.typ option)
    (e : (a, 'm) A.gexpr) : (a, A.typed A.mark) A.gexpr A.box =
  let typed_e =
    match typ with
    | Some expected ->
      wrap ctx (typecheck_expr_top_down ctx env (ast_to_typ expected)) e
    | None -> wrap ctx (typecheck_expr_bottom_up ctx env) e
  in
  Expr.map_marks ~f:get_ty_mark typed_e
|
|
|
|
|
|
|
|
(** Typechecks a chain of scope [let]-bindings ending in a [Result].

    The final result is checked top-down against [ty_out]. Each [let] expression
    is inferred bottom-up and then unified with its declared type, and the bound
    variable is added to [env] with the declared type before the rest of the
    chain is checked. Returns the boxed, re-marked chain. *)
let rec scope_body_expr ctx env ty_out body_expr =
  match body_expr with
  | A.Result result_e ->
    let typed_result =
      wrap ctx (typecheck_expr_top_down ctx env ty_out) result_e
    in
    Expr.map_marks ~f:get_ty_mark typed_result
    |> Bindlib.box_apply (fun r -> A.Result r)
  | A.ScopeLet
      {
        scope_let_kind;
        scope_let_typ;
        scope_let_expr = e0;
        scope_let_next;
        scope_let_pos;
      } ->
    let declared_ty = ast_to_typ scope_let_typ in
    let inferred = wrap ctx (typecheck_expr_bottom_up ctx env) e0 in
    (* We could use [typecheck_expr_top_down] rather than this manual
       unification, but we get better messages with this order of the [unify]
       parameters, which keeps location of the type as defined instead of as
       inferred. *)
    wrap ctx
      (fun expected -> Bindlib.box (unify ctx e0 (ty inferred) expected))
      declared_ty;
    let let_var, rest = Bindlib.unbind scope_let_next in
    let typed_rest =
      scope_body_expr ctx (Env.add let_var declared_ty env) ty_out rest
    in
    let rebound_next = Bindlib.bind_var (Var.translate let_var) typed_rest in
    (* The marks of [inferred] are resolved only after the rest of the chain
       has been typed, so later unifications are reflected in them. *)
    Bindlib.box_apply2
      (fun scope_let_expr scope_let_next ->
        A.ScopeLet
          {
            scope_let_kind;
            scope_let_typ;
            scope_let_expr;
            scope_let_next;
            scope_let_pos;
          })
      (Expr.map_marks ~f:get_ty_mark inferred)
      rebound_next
|
|
|
|
|
|
|
|
(** Typechecks a scope body. The bound input variable is given the scope's input
    struct type, and the body is checked against the output struct type.
    Returns the rebuilt boxed binder paired with the scope's function type
    [input_struct -> output_struct], positioned at the output struct's
    declaration. *)
let scope_body ctx env body =
  let struct_ty s_name =
    let decl_pos = Marked.get_mark (A.StructName.get_info s_name) in
    UnionFind.make (Marked.mark decl_pos (TStruct s_name))
  in
  let input_struct = body.A.scope_body_input_struct in
  let output_struct = body.A.scope_body_output_struct in
  let ty_in = struct_ty input_struct in
  let ty_out = struct_ty output_struct in
  let input_var, inner = Bindlib.unbind body.A.scope_body_expr in
  let typed_inner =
    scope_body_expr ctx (Env.add input_var ty_in env) ty_out inner
  in
  let arrow_pos = Marked.get_mark (A.StructName.get_info output_struct) in
  ( Bindlib.bind_var (Var.translate input_var) typed_inner,
    UnionFind.make (Marked.mark arrow_pos (TArrow (ty_in, ty_out))) )
|
|
|
|
|
|
|
|
(** Typechecks a chain of scope definitions. Each scope's variable is bound in
    [env] to the function type computed by [scope_body] before the following
    scopes are checked. *)
let rec scopes ctx env scope_chain =
  match scope_chain with
  | A.Nil -> Bindlib.box A.Nil
  | A.ScopeDef def ->
    let typed_body, scope_ty = scope_body ctx env def.scope_body in
    let scope_var, remaining = Bindlib.unbind def.scope_next in
    let typed_remaining =
      scopes ctx (Env.add scope_var scope_ty env) remaining
    in
    let rebound_next =
      Bindlib.bind_var (Var.translate scope_var) typed_remaining
    in
    Bindlib.box_apply2
      (fun scope_body_expr scope_next ->
        A.ScopeDef
          {
            def with
            scope_body = { def.scope_body with scope_body_expr };
            scope_next;
          })
      typed_body rebound_next
|
|
|
|
|
|
|
|
(** Typechecks every scope of [prg], starting from an empty environment, and
    returns the program with its scopes replaced by their typed versions. *)
let program prg =
  let typed_scopes = scopes prg.A.decl_ctx Env.empty prg.A.scopes in
  { prg with scopes = Bindlib.unbox typed_scopes }
|