diff --git a/src/catala/default_calculus/typing.ml b/src/catala/default_calculus/typing.ml index f0194527..3422f6c5 100644 --- a/src/catala/default_calculus/typing.ml +++ b/src/catala/default_calculus/typing.ml @@ -68,31 +68,45 @@ let rec format_typ (fmt : Format.formatter) (typ : typ Pos.marked UnionFind.elem (** Raises an error if unification cannot be performed *) let rec unify (t1 : typ Pos.marked UnionFind.elem) (t2 : typ Pos.marked UnionFind.elem) : unit = - (* Cli.debug_print (Format.asprintf "Unifying %a and %a" format_typ t1 format_typ t2); *) + Cli.debug_print (Format.asprintf "Unifying %a and %a" format_typ t1 format_typ t2); let t1_repr = UnionFind.get (UnionFind.find t1) in let t2_repr = UnionFind.get (UnionFind.find t2) in - match (t1_repr, t2_repr) with - | (TLit tl1, _), (TLit tl2, _) when tl1 = tl2 -> () - | (TArrow (t11, t12), _), (TArrow (t21, t22), _) -> - unify t11 t21; - unify t12 t22 - | (TTuple ts1, _), (TTuple ts2, _) -> List.iter2 unify ts1 ts2 - | (TEnum ts1, _), (TEnum ts2, _) -> List.iter2 unify ts1 ts2 - | (TArray t1', _), (TArray t2', _) -> unify t1' t2' - | (TAny _, _), (TAny _, _) -> ignore (UnionFind.union t1 t2) - | (TAny _, _), t_repr | t_repr, (TAny _, _) -> - let t_union = UnionFind.union t1 t2 in - ignore (UnionFind.set t_union t_repr) - | (_, t1_pos), (_, t2_pos) -> - (* TODO: if we get weird error messages, then it means that we should use the persistent - version of the union-find data structure. *) - Errors.raise_multispanned_error - (Format.asprintf "Error during typechecking, types %a and %a are incompatible" format_typ t1 - format_typ t2) - [ - (Some (Format.asprintf "Type %a coming from expression:" format_typ t1), t1_pos); - (Some (Format.asprintf "Type %a coming from expression:" format_typ t2), t2_pos); - ] + let repr = + match (t1_repr, t2_repr) with + | (TLit tl1, _), (TLit tl2, _) when tl1 = tl2 -> None + | (TArrow (t11, t12), _), (TArrow (t21, t22), _) -> + unify t11 t21; + unify t12 t22; + None + | (TTuple ts1, _), (TTuple ts2, _) -> + List.iter2 unify ts1 ts2; + None + | (TEnum ts1, _), (TEnum ts2, _) -> + List.iter2 unify ts1 ts2; + None + | (TArray t1', _), (TArray t2', _) -> + unify t1' t2'; + None + | (TAny _, _), (TAny _, _) -> None + | (TAny _, _), t_repr | t_repr, (TAny _, _) -> Some t_repr + | (_, t1_pos), (_, t2_pos) -> + (* TODO: if we get weird error messages, then it means that we should use the persistent + version of the union-find data structure. *) + Errors.raise_multispanned_error + (Format.asprintf "Error during typechecking, types %a and %a are incompatible" format_typ + t1 format_typ t2) + [ + (Some (Format.asprintf "Type %a coming from expression:" format_typ t1), t1_pos); + (Some (Format.asprintf "Type %a coming from expression:" format_typ t2), t2_pos); + ] + in + let t_union = UnionFind.union t1 t2 in + match repr with + | None -> () + | Some t_repr -> + Cli.debug_print + (Format.asprintf "Setting %a to %a" format_typ t_union format_typ (UnionFind.make t_repr)); + UnionFind.set t_union t_repr (** Operators have a single type, instead of being polymorphic with constraints. This allows us to have a simpler type system, while we argue the syntactic burden of operator annotations helps @@ -236,18 +250,16 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m | EAbs (pos_binder, binder, taus) -> let xs, body = Bindlib.unmbind binder in if Array.length xs = List.length taus then - let xstaus = List.map2 (fun x tau -> (x, tau)) (Array.to_list xs) taus in - let env = - List.fold_left - (fun env (x, tau) -> - A.VarMap.add x (ast_to_typ (Pos.unmark tau), Pos.get_position tau) env) - env xstaus + let xstaus = + List.map2 + (fun x tau -> (x, (ast_to_typ (Pos.unmark tau), Pos.get_position tau))) + (Array.to_list xs) taus in + let env = List.fold_left (fun env (x, tau) -> A.VarMap.add x tau env) env xstaus in List.fold_right - (fun t_arg (acc : typ Pos.marked UnionFind.elem) -> - UnionFind.make - (TArrow (UnionFind.make (Pos.map_under_mark ast_to_typ t_arg), acc), pos_binder)) - taus + (fun (_, t_arg) (acc : typ Pos.marked UnionFind.elem) -> + UnionFind.make (TArrow (UnionFind.make t_arg, acc), pos_binder)) + xstaus (typecheck_expr_bottom_up env body) else Errors.raise_spanned_error @@ -287,13 +299,13 @@ let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.m es; UnionFind.make (Pos.same_pos_as (TArray cell_type) e) in - (* Cli.debug_print (Format.asprintf "Found type of %a: %a" Print.format_expr e format_typ out); *) + Cli.debug_print (Format.asprintf "Found type of %a: %a" Print.format_expr e format_typ out); out (** Checks whether the expression can be typed with the provided type *) and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked) (tau : typ Pos.marked UnionFind.elem) : unit = - (* Cli.debug_print (Format.asprintf "Typechecking %a : %a" Print.format_expr e format_typ tau); *) + Cli.debug_print (Format.asprintf "Typechecking %a : %a" Print.format_expr e format_typ tau); match Pos.unmark e with | EVar v -> ( match A.VarMap.find_opt (Pos.unmark v) env with @@ -376,12 +388,12 @@ and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked) | EAbs (pos_binder, binder, t_args) -> let xs, body = Bindlib.unmbind binder in if Array.length xs = List.length t_args then - let xstaus = List.map2 (fun x t_arg -> (x, t_arg)) (Array.to_list xs) t_args in - let env = - List.fold_left - (fun env (x, t_arg) -> A.VarMap.add x (ast_to_typ (Pos.unmark t_arg), pos_binder) env) - env xstaus + let xstaus = + List.map2 + (fun x t_arg -> (x, Pos.map_under_mark ast_to_typ t_arg)) + (Array.to_list xs) t_args in + let env = List.fold_left (fun env (x, t_arg) -> A.VarMap.add x t_arg env) env xstaus in let t_out = typecheck_expr_bottom_up env body in let t_func = List.fold_right diff --git a/tests/test_array/fold_error.catala b/tests/test_array/fold_error.catala new file mode 100644 index 00000000..f2179c06 --- /dev/null +++ b/tests/test_array/fold_error.catala @@ -0,0 +1,11 @@ +@Article@ + +/* +new scope A: + param list content set int + param list_high_count content int + +scope A: + def list := [0; 5; 6; 7; 1; 64; 12] + def list_high_count := number for m in list of (m >=$ $7) +*/ \ No newline at end of file diff --git a/tests/test_money/no_mingle.catala.A.out b/tests/test_money/no_mingle.catala.A.out index 54266e36..f22b4673 100644 --- a/tests/test_money/no_mingle.catala.A.out +++ b/tests/test_money/no_mingle.catala.A.out @@ -9,5 +9,5 @@ [ERROR] Type money coming from expression: [ERROR] --> test_money/no_mingle.catala [ERROR] | -[ERROR] 4 | new scope A: -[ERROR] | ^ +[ERROR] 6 | param y content money +[ERROR] | ^^^^^