Typing: add a "assume operator types" mode (#575)

This commit is contained in:
Louis Gesbert 2024-02-09 18:07:22 +01:00 committed by GitHub
commit 68aaf6e2f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 189 additions and 209 deletions

View File

@ -461,7 +461,6 @@ let base_bindings catala_exe catala_flags build_dir include_dirs =
Nj.binding Var.clerk_flags
("-e"
:: Var.(!catala_exe)
:: ("--build-dir=" ^ Var.(!builddir))
:: includes
@ List.map (fun f -> "--catala-opts=" ^ f) catala_flags);
Nj.binding Var.ocamlopt_exe ["ocamlopt"];
@ -498,7 +497,7 @@ let[@ocamlformat "disable"] static_base_rules =
Nj.rule "out-test"
~command: [
!catala_exe; !test_command; !catala_flags; !input;
!catala_exe; !test_command; "--plugin-dir="; "-o -"; !catala_flags; !input;
">"; !output; "2>&1";
"||"; "true";
]
@ -985,16 +984,15 @@ let run_cmd =
$ Cli.ninja_flags)
let runtest_cmd =
let run catala_exe catala_opts build_dir include_dirs file =
let run catala_exe catala_opts include_dirs file =
let catala_opts =
List.fold_left
(fun opts dir -> "-I" :: dir :: opts)
catala_opts include_dirs
in
let build_dir = Poll.build_dir ?dir:build_dir () in
Clerk_runtest.run_inline_tests
(Option.value ~default:"catala" catala_exe)
catala_opts build_dir file;
catala_opts file;
0
in
let doc =
@ -1006,7 +1004,6 @@ let runtest_cmd =
const run
$ Cli.catala_exe
$ Cli.catala_opts
$ Cli.build_dir
$ Cli.include_dirs
$ Cli.single_file)

View File

@ -16,7 +16,7 @@
open Catala_utils
let run_catala_test catala_exe catala_opts build_dir file program args oc =
let run_catala_test catala_exe catala_opts file program args oc =
let cmd_in_rd, cmd_in_wr = Unix.pipe () in
Unix.set_close_on_exec cmd_in_wr;
let command_oc = Unix.out_channel_of_descr cmd_in_wr in
@ -41,7 +41,6 @@ let run_catala_test catala_exe catala_opts build_dir file program args oc =
|> Seq.cons "CATALA_OUT=-"
(* |> Seq.cons "CATALA_COLOR=never" *)
|> Seq.cons "CATALA_PLUGINS="
|> Seq.cons ("CATALA_BUILD_DIR=" ^ build_dir)
|> Array.of_seq
in
flush oc;
@ -59,7 +58,7 @@ let run_catala_test catala_exe catala_opts build_dir file program args oc =
(** Directly runs the test (not using ninja, this will be called by ninja rules
through the "clerk runtest" command) *)
let run_inline_tests catala_exe catala_opts build_dir filename =
let run_inline_tests catala_exe catala_opts filename =
let module L = Surface.Lexer_common in
let lang =
match Clerk_scan.get_lang filename with
@ -95,7 +94,7 @@ let run_inline_tests catala_exe catala_opts build_dir filename =
skip_block lines
| Some args ->
let args = String.split_on_char ' ' args in
run_catala_test catala_exe catala_opts build_dir filename
run_catala_test catala_exe catala_opts filename
lines_until_now args oc;
skip_block lines)
and skip_block lines =

View File

@ -22,7 +22,7 @@
open Catala_utils
val run_inline_tests : string -> string list -> File.t -> File.t -> unit
(** [run_inline_tests catala_exe catala_opts build_dir file] runs the tests in
val run_inline_tests : string -> string list -> File.t -> unit
(** [run_inline_tests catala_exe catala_opts file] runs the tests in
Catala [file] using the given path to the Catala executable and the provided
options. Output is printed to [stdout]. *)

View File

@ -23,7 +23,7 @@ let expr ctx env e =
[Some] *)
(* Intermediate unboxings are fine since the [check_expr] will rebox in
depth *)
Typing.check_expr ~leave_unresolved:ErrorOnAny ctx ~env (Expr.unbox e)
Typing.check_expr ctx ~env (Expr.unbox e)
let rule ctx env rule =
let env =

View File

@ -192,7 +192,7 @@ module Passes = struct
match typed with
| Typed _ -> (
Message.emit_debug "Typechecking again...";
try Typing.program ~leave_unresolved:ErrorOnAny prg
try Typing.program prg
with Message.CompilerError error_content ->
let bt = Printexc.get_raw_backtrace () in
Printexc.raise_with_backtrace
@ -257,7 +257,7 @@ module Passes = struct
let prg =
if not closure_conversion then (
Message.emit_debug "Retyping lambda calculus...";
Typing.program ~leave_unresolved:LeaveAny prg)
Typing.program ~fail_on_any:false prg)
else (
Message.emit_debug "Performing closure conversion...";
let prg = Lcalc.Closure_conversion.closure_conversion prg in
@ -268,16 +268,15 @@ module Passes = struct
else prg
in
Message.emit_debug "Retyping lambda calculus...";
Typing.program ~leave_unresolved:LeaveAny prg)
Typing.program ~fail_on_any:false prg)
in
let prg, type_ordering =
if monomorphize_types then (
Message.emit_debug "Monomorphizing types...";
Lcalc.Monomorphize.program prg
(* (* FIXME: typing no longer works after monomorphisation, it would
* need special operator handling for arrays and options *)
* Message.emit_debug "Retyping lambda calculus...";
* let prg = Typing.program ~leave_unresolved:LeaveAny prg in *))
let prg, type_ordering = Lcalc.Monomorphize.program prg in
Message.emit_debug "Retyping lambda calculus...";
let prg = Typing.program ~fail_on_any:false ~assume_op_types:true prg in
prg, type_ordering)
else prg, type_ordering
in
prg, type_ordering
@ -556,10 +555,7 @@ module Commands = struct
(* Additionally, we might want to check the invariants. *)
if check_invariants then (
let prg =
Shared_ast.Typing.program ~leave_unresolved:ErrorOnAny
(Program.untype prg)
in
let prg = Shared_ast.Typing.program prg in
Message.emit_debug "Checking invariants...";
if Dcalc.Invariants.check_all_invariants prg then
Message.emit_result "All invariant checks passed"

View File

@ -67,15 +67,11 @@ type 'm program = {
let type_rule decl_ctx env = function
| Definition (loc, typ, io, expr) ->
let expr' =
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
in
let expr' = Typing.expr decl_ctx ~env ~typ expr in
Definition (loc, typ, io, Expr.unbox expr')
| Assertion expr ->
let typ = Mark.add (Expr.pos expr) (TLit TBool) in
let expr' =
Typing.expr ~leave_unresolved:ErrorOnAny decl_ctx ~env ~typ expr
in
let expr' = Typing.expr decl_ctx ~env ~typ expr in
Assertion (Expr.unbox expr')
| Call (sc_name, ssc_name, m) ->
let pos = Expr.mark_pos m in
@ -118,10 +114,7 @@ let type_program (type m) (prg : m program) : typed program =
let program_topdefs =
TopdefName.Map.map
(fun (expr, typ) ->
( Expr.unbox
(Typing.expr prg.program_ctx ~leave_unresolved:ErrorOnAny ~env ~typ
expr),
typ ))
Expr.unbox (Typing.expr prg.program_ctx ~env ~typ expr), typ)
prg.program_topdefs
in
let program_scopes =

View File

@ -562,14 +562,17 @@ module ExprGen (C : EXPR_PARAM) = struct
Format.fprintf fmt "@[<hv 0>%a @[<hv 2>%a@]@ @]%a@ %a" punctuation "λ"
(Format.pp_print_list ~pp_sep:Format.pp_print_space
(fun fmt (x, tau) ->
punctuation fmt "(";
Format.pp_open_hvbox fmt 2;
var fmt x;
punctuation fmt ":";
Format.pp_print_space fmt ();
typ_gen None ~colors fmt tau;
Format.pp_close_box fmt ();
punctuation fmt ")"))
match tau with
| TLit TUnit, _ -> punctuation fmt "("; punctuation fmt ")"
| _ ->
punctuation fmt "(";
Format.pp_open_hvbox fmt 2;
var fmt x;
punctuation fmt ":";
Format.pp_print_space fmt ();
typ_gen None ~colors fmt tau;
Format.pp_close_box fmt ();
punctuation fmt ")"))
xs_tau punctuation "" (rhs expr) body
| EAppOp { op = (Map | Filter) as op; args = [arg1; arg2]; _ } ->
Format.fprintf fmt "@[<hv 2>%a %a@ %a@]" operator op (lhs exprc) arg1
@ -704,13 +707,20 @@ module ExprGen (C : EXPR_PARAM) = struct
~pp_sep:(fun fmt () -> Format.fprintf fmt "@\n")
(fun fmt pp_cons_name case_expr ->
match case_expr with
| EAbs { binder; _ }, _ ->
| EAbs { binder; tys; _ }, _ ->
let xs, body, bnd_ctx = Bindlib.unmbind_in bnd_ctx binder in
let expr = exprb bnd_ctx in
Format.fprintf fmt "@[<hov 2>%a %t@ %a@ %a@ %a@]" punctuation
let pp_args fmt = match tys with
| [TLit TUnit, _] -> ()
| _ ->
Format.pp_print_seq ~pp_sep:Format.pp_print_space var fmt
(Array.to_seq xs);
Format.pp_print_space fmt ()
in
Format.fprintf fmt "@[<hov 2>%a %t@ %t%a@ %a@]" punctuation
"|" pp_cons_name
(Format.pp_print_seq ~pp_sep:Format.pp_print_space var)
(Array.to_seq xs) punctuation "" (rhs expr) body
pp_args
punctuation "" (rhs expr) body
| e ->
Format.fprintf fmt "@[<hov 2>%a %t@ %a@ %a@]" punctuation "|"
pp_cons_name punctuation "" (rhs exprc) e))

View File

@ -20,7 +20,7 @@
open Catala_utils
module A = Definitions
type resolving_strategy = LeaveAny | ErrorOnAny
type flags = { fail_on_any : bool; assume_op_types : bool }
module Any =
Uid.Make
@ -54,9 +54,8 @@ and naked_typ =
| TAny of Any.t
| TClosureEnv
let rec typ_to_ast ~(leave_unresolved : resolving_strategy) (ty : unionfind_typ)
: A.typ =
let typ_to_ast = typ_to_ast ~leave_unresolved in
let rec typ_to_ast ~(flags : flags) (ty : unionfind_typ) : A.typ =
let typ_to_ast = typ_to_ast ~flags in
let ty, pos = UnionFind.get (UnionFind.find ty) in
match ty with
| TLit l -> A.TLit l, pos
@ -67,15 +66,14 @@ let rec typ_to_ast ~(leave_unresolved : resolving_strategy) (ty : unionfind_typ)
| TArrow (t1, t2) -> A.TArrow (List.map typ_to_ast t1, typ_to_ast t2), pos
| TArray t1 -> A.TArray (typ_to_ast t1), pos
| TDefault t1 -> A.TDefault (typ_to_ast t1), pos
| TAny _ -> (
match leave_unresolved with
| LeaveAny -> A.TAny, pos
| ErrorOnAny ->
| TAny _ ->
if flags.fail_on_any then
(* No polymorphism in Catala: type inference should return full types
without wildcards, and this function is used to recover the types after
typing. *)
Message.raise_spanned_error pos
"Internal error: typing at this point could not be resolved")
"Internal error: typing at this point could not be resolved"
else A.TAny, pos
| TClosureEnv -> TClosureEnv, pos
let rec ast_to_typ (ty : A.typ) : unionfind_typ =
@ -321,8 +319,39 @@ let polymorphic_op_type (op : Operator.polymorphic A.operator Mark.pos) :
in
Lazy.force ty
(* Just returns the return type of the operator, assuming the operand types are
known. Less trict, but useful on monomorphised code where the operators no
longer have their standard types *)
let polymorphic_op_return_type
ctx
e
(op : Operator.polymorphic A.operator Mark.pos)
(targs : unionfind_typ list) : unionfind_typ =
let open Operator in
let pos = Mark.get op in
let uf t = UnionFind.make (t, pos) in
let any _ = uf (TAny (Any.fresh ())) in
let return_type tf arity =
let tret = any () in
unify ctx e tf (UnionFind.make (TArrow (List.init arity any, tret), pos));
tret
in
match Mark.remove op, targs with
| Fold, [_; tau; _] -> tau
| Eq, _ -> uf (TLit TBool)
| Map, [tf; _] -> uf (TArray (return_type tf 1))
| Map2, [tf; _; _] -> uf (TArray (return_type tf 2))
| (Filter | Reduce | Concat), [_; tau] -> tau
| Log (PosRecordIfTrueBool, _), _ -> uf (TLit TBool)
| Log _, [tau] -> tau
| Length, _ -> uf (TLit TInt)
| (HandleDefault | HandleDefaultOpt), [_; _; tf] -> return_type tf 1
| ToClosureEnv, _ -> uf TClosureEnv
| FromClosureEnv, _ -> any ()
| _ -> Message.raise_spanned_error pos "Mismatched operator arguments"
let resolve_overload_ret_type
~leave_unresolved
~flags
(ctx : A.decl_ctx)
e
(op : Operator.overloaded A.operator)
@ -330,7 +359,7 @@ let resolve_overload_ret_type
let op_ty =
Operator.overload_type ctx
(Mark.add (Expr.pos e) op)
(List.map (typ_to_ast ~leave_unresolved) tys)
(List.map (typ_to_ast ~flags) tys)
in
ast_to_typ (Type.arrow_return op_ty)
@ -338,6 +367,7 @@ let resolve_overload_ret_type
module Env = struct
type 'e t = {
flags : flags;
structs : unionfind_typ A.StructField.Map.t A.StructName.Map.t;
enums : unionfind_typ A.EnumConstructor.Map.t A.EnumName.Map.t;
vars : ('e, unionfind_typ) Var.Map.t;
@ -347,10 +377,14 @@ module Env = struct
toplevel_vars : A.typ A.TopdefName.Map.t;
}
let empty (decl_ctx : A.decl_ctx) =
let empty
?(fail_on_any = true)
?(assume_op_types = false)
(decl_ctx : A.decl_ctx) =
(* We fill the environment initially with the structs and enums
declarations *)
{
flags = { fail_on_any; assume_op_types };
structs =
A.StructName.Map.map
(fun ty -> A.StructField.Map.map ast_to_typ ty)
@ -423,29 +457,28 @@ let ty : (_, unionfind_typ A.custom) A.marked -> unionfind_typ =
(** Infers the most permissive type from an expression *)
let rec typecheck_expr_bottom_up :
type a m.
leave_unresolved:resolving_strategy ->
A.decl_ctx ->
(a, m) A.gexpr Env.t ->
(a, m) A.gexpr ->
(a, unionfind_typ A.custom) A.boxed_gexpr =
fun ~leave_unresolved ctx env e ->
typecheck_expr_top_down ~leave_unresolved ctx env
fun ctx env e ->
typecheck_expr_top_down ctx env
(UnionFind.make (add_pos e (TAny (Any.fresh ()))))
e
(** Checks whether the expression can be typed with the provided type *)
and typecheck_expr_top_down :
type a m.
leave_unresolved:resolving_strategy ->
A.decl_ctx ->
(a, m) A.gexpr Env.t ->
unionfind_typ ->
(a, m) A.gexpr ->
(a, unionfind_typ A.custom) A.boxed_gexpr =
fun ~leave_unresolved ctx env tau e ->
fun ctx env tau e ->
(* Message.emit_debug "Propagating type %a for naked_expr :@.@[<hov 2>%a@]"
(format_typ ctx) tau Expr.format e; *)
let pos_e = Expr.pos e in
let flags = env.flags in
let () =
(* If there already is a type annotation on the given expr, ensure it
matches *)
@ -519,7 +552,7 @@ and typecheck_expr_top_down :
A.StructField.Map.mapi
(fun f_name f_e ->
let f_ty = A.StructField.Map.find f_name str in
typecheck_expr_top_down ~leave_unresolved ctx env f_ty f_e)
typecheck_expr_top_down ctx env f_ty f_e)
fields
in
Expr.estruct ~name ~fields mark
@ -530,8 +563,7 @@ and typecheck_expr_top_down :
| None -> TAny (Any.fresh ())
in
let e_struct' =
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind t_struct)
e_struct
typecheck_expr_top_down ctx env (unionfind t_struct) e_struct
in
let name =
match UnionFind.get (ty e_struct') with
@ -598,8 +630,7 @@ and typecheck_expr_top_down :
in
let mark = mark_with_tau_and_unify fld_ty in
let e_struct' =
typecheck_expr_top_down ~leave_unresolved ctx env
(unionfind (TStruct name)) e_struct
typecheck_expr_top_down ctx env (unionfind (TStruct name)) e_struct
in
Expr.estructaccess ~e:e_struct' ~field ~name mark
| A.EInj { name; cons; e = e_enum }
@ -607,23 +638,20 @@ and typecheck_expr_top_down :
if Definitions.EnumConstructor.equal cons Expr.some_constr then
let cell_type = unionfind (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
let e_enum' =
typecheck_expr_top_down ~leave_unresolved ctx env cell_type e_enum
in
let e_enum' = typecheck_expr_top_down ctx env cell_type e_enum in
Expr.einj ~name ~cons ~e:e_enum' mark
else
(* None constructor *)
let cell_type = unionfind (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify (unionfind (TOption cell_type)) in
let e_enum' =
typecheck_expr_top_down ~leave_unresolved ctx env
(unionfind (TLit TUnit)) e_enum
typecheck_expr_top_down ctx env (unionfind (TLit TUnit)) e_enum
in
Expr.einj ~name ~cons ~e:e_enum' mark
| A.EInj { name; cons; e = e_enum } ->
let mark = mark_with_tau_and_unify (unionfind (TEnum name)) in
let e_enum' =
typecheck_expr_top_down ~leave_unresolved ctx env
typecheck_expr_top_down ctx env
(A.EnumConstructor.Map.find cons (A.EnumName.Map.find name env.enums))
e_enum
in
@ -640,14 +668,14 @@ and typecheck_expr_top_down :
in
let t_ret = unionfind ~pos:e (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify t_ret in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_arg e1 in
let e1' = typecheck_expr_top_down ctx env t_arg e1 in
let cases =
A.EnumConstructor.Map.merge
(fun _ e e_ty ->
match e, e_ty with
| Some e, Some e_ty ->
Some
(typecheck_expr_top_down ~leave_unresolved ctx env
(typecheck_expr_top_down ctx env
(unionfind ~pos:e (TArrow ([e_ty], t_ret)))
e)
| _ -> assert false)
@ -658,10 +686,7 @@ and typecheck_expr_top_down :
let cases_ty = A.EnumName.Map.find name ctx.A.ctx_enums in
let t_ret = unionfind ~pos:e1 (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify t_ret in
let e1' =
typecheck_expr_top_down ~leave_unresolved ctx env (unionfind (TEnum name))
e1
in
let e1' = typecheck_expr_top_down ctx env (unionfind (TEnum name)) e1 in
let cases =
A.EnumConstructor.Map.mapi
(fun c_name e ->
@ -670,7 +695,7 @@ and typecheck_expr_top_down :
there is a change to allow for multiple arguments, it might be
easier to use tuples directly. *)
let e_ty = unionfind ~pos:e (TArrow ([ast_to_typ c_ty], t_ret)) in
typecheck_expr_top_down ~leave_unresolved ctx env e_ty e)
typecheck_expr_top_down ctx env e_ty e)
cases
in
Expr.ematch ~e:e1' ~name ~cases mark
@ -683,17 +708,15 @@ and typecheck_expr_top_down :
let args' =
A.ScopeVar.Map.mapi
(fun name ->
typecheck_expr_top_down ~leave_unresolved ctx env
typecheck_expr_top_down ctx env
(ast_to_typ (A.ScopeVar.Map.find name vars)))
args
in
Expr.escopecall ~scope ~args:args' mark
| A.ERaise ex -> Expr.eraise ex context_mark
| A.ECatch { body; exn; handler } ->
let body' = typecheck_expr_top_down ~leave_unresolved ctx env tau body in
let handler' =
typecheck_expr_top_down ~leave_unresolved ctx env tau handler
in
let body' = typecheck_expr_top_down ctx env tau body in
let handler' = typecheck_expr_top_down ctx env tau handler in
Expr.ecatch body' exn handler' context_mark
| A.EVar v ->
let tau' =
@ -732,9 +755,7 @@ and typecheck_expr_top_down :
| A.ETuple es ->
let tys = List.map (fun _ -> unionfind (TAny (Any.fresh ()))) es in
let mark = mark_with_tau_and_unify (unionfind (TTuple tys)) in
let es' =
List.map2 (typecheck_expr_top_down ~leave_unresolved ctx env) tys es
in
let es' = List.map2 (typecheck_expr_top_down ctx env) tys es in
Expr.etuple es' mark
| A.ETupleAccess { e = e1; index; size } ->
if index >= size then
@ -745,11 +766,7 @@ and typecheck_expr_top_down :
(List.init size (fun n ->
if n = index then tau else unionfind ~pos:e1 (TAny (Any.fresh ()))))
in
let e1' =
typecheck_expr_top_down ~leave_unresolved ctx env
(unionfind ~pos:e1 tuple_ty)
e1
in
let e1' = typecheck_expr_top_down ctx env (unionfind ~pos:e1 tuple_ty) e1 in
Expr.etupleaccess ~e:e1' ~index ~size context_mark
| A.EAbs { binder; tys = t_args } ->
if Bindlib.mbinder_arity binder <> List.length t_args then
@ -769,11 +786,9 @@ and typecheck_expr_top_down :
(fun env x tau_arg -> Env.add x tau_arg env)
env (Array.to_list xs) tau_args
in
let body' =
typecheck_expr_top_down ~leave_unresolved ctx env t_ret body
in
let body' = typecheck_expr_top_down ctx env t_ret body in
let binder' = Bindlib.bind_mvar xs' (Expr.Box.lift body') in
Expr.eabs binder' (List.map (typ_to_ast ~leave_unresolved) tau_args) mark
Expr.eabs binder' (List.map (typ_to_ast ~flags) tau_args) mark
| A.EApp { f = e1; args; tys } ->
(* Here we type the arguments first (in order), to ensure we know the types
of the arguments if [f] is [EAbs] before disambiguation. This is also the
@ -783,9 +798,7 @@ and typecheck_expr_top_down :
| [] -> List.map (fun _ -> unionfind (TAny (Any.fresh ()))) args
| tys -> List.map ast_to_typ tys
in
let args' =
List.map2 (typecheck_expr_top_down ~leave_unresolved ctx env) t_args args
in
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
let t_args =
match t_args, tys with
| [t], [] -> (
@ -805,9 +818,9 @@ and typecheck_expr_top_down :
t_args
in
let t_func = unionfind ~pos:e1 (TArrow (t_args, tau)) in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env t_func e1 in
let e1' = typecheck_expr_top_down ctx env t_func e1 in
Expr.eapp ~f:e1' ~args:args'
~tys:(List.map (typ_to_ast ~leave_unresolved) t_args)
~tys:(List.map (typ_to_ast ~flags) t_args)
context_mark
| A.EAppOp { op; tys; args } ->
let t_args = List.map ast_to_typ tys in
@ -818,87 +831,73 @@ and typecheck_expr_top_down :
(* Type the operator first, then right-to-left: polymorphic operators
are required to allow the resolution of all type variables this
way *)
unify ctx e (polymorphic_op_type (Mark.add pos_e op)) t_func;
if not env.flags.assume_op_types then
unify ctx e (polymorphic_op_type (Mark.add pos_e op)) t_func
else
unify ctx e
(polymorphic_op_return_type ctx e (Mark.add pos_e op) t_args)
tau;
List.rev_map2
(typecheck_expr_top_down ~leave_unresolved ctx env)
(typecheck_expr_top_down ctx env)
(List.rev t_args) (List.rev args))
~overloaded:(fun op ->
(* Typing the arguments first is required to resolve the operator *)
let args' =
List.map2
(typecheck_expr_top_down ~leave_unresolved ctx env)
t_args args
in
unify ctx e tau
(resolve_overload_ret_type ~leave_unresolved ctx e op t_args);
let args' = List.map2 (typecheck_expr_top_down ctx env) t_args args in
unify ctx e tau (resolve_overload_ret_type ~flags ctx e op t_args);
args')
~monomorphic:(fun op ->
(* Here it doesn't matter but may affect the error messages *)
unify ctx e
(ast_to_typ (Operator.monomorphic_type (Mark.add pos_e op)))
t_func;
List.map2
(typecheck_expr_top_down ~leave_unresolved ctx env)
t_args args)
List.map2 (typecheck_expr_top_down ctx env) t_args args)
~resolved:(fun op ->
(* This case should not fail *)
unify ctx e
(ast_to_typ (Operator.resolved_type (Mark.add pos_e op)))
t_func;
List.map2
(typecheck_expr_top_down ~leave_unresolved ctx env)
t_args args)
List.map2 (typecheck_expr_top_down ctx env) t_args args)
in
(* All operator applications are monomorphised at this point *)
let tys = List.map (typ_to_ast ~leave_unresolved) t_args in
let tys = List.map (typ_to_ast ~flags) t_args in
Expr.eappop ~op ~args ~tys context_mark
| A.EDefault { excepts; just; cons } ->
let cons' = typecheck_expr_top_down ~leave_unresolved ctx env tau cons in
let cons' = typecheck_expr_top_down ctx env tau cons in
let just' =
typecheck_expr_top_down ~leave_unresolved ctx env
(unionfind ~pos:just (TLit TBool))
just
in
let excepts' =
List.map (typecheck_expr_top_down ~leave_unresolved ctx env tau) excepts
typecheck_expr_top_down ctx env (unionfind ~pos:just (TLit TBool)) just
in
let excepts' = List.map (typecheck_expr_top_down ctx env tau) excepts in
Expr.edefault ~excepts:excepts' ~just:just' ~cons:cons' context_mark
| A.EPureDefault e1 ->
let inner_ty = unionfind ~pos:e1 (TAny (Any.fresh ())) in
let mark =
mark_with_tau_and_unify (unionfind ~pos:e1 (TDefault inner_ty))
in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env inner_ty e1 in
let e1' = typecheck_expr_top_down ctx env inner_ty e1 in
Expr.epuredefault e1' mark
| A.EIfThenElse { cond; etrue = et; efalse = ef } ->
let et' = typecheck_expr_top_down ~leave_unresolved ctx env tau et in
let ef' = typecheck_expr_top_down ~leave_unresolved ctx env tau ef in
let et' = typecheck_expr_top_down ctx env tau et in
let ef' = typecheck_expr_top_down ctx env tau ef in
let cond' =
typecheck_expr_top_down ~leave_unresolved ctx env
(unionfind ~pos:cond (TLit TBool))
cond
typecheck_expr_top_down ctx env (unionfind ~pos:cond (TLit TBool)) cond
in
Expr.eifthenelse cond' et' ef' context_mark
| A.EAssert e1 ->
let mark = mark_with_tau_and_unify (unionfind (TLit TUnit)) in
let e1' =
typecheck_expr_top_down ~leave_unresolved ctx env
(unionfind ~pos:e1 (TLit TBool))
e1
typecheck_expr_top_down ctx env (unionfind ~pos:e1 (TLit TBool)) e1
in
Expr.eassert e1' mark
| A.EEmptyError ->
Expr.eemptyerror (ty_mark (TDefault (unionfind (TAny (Any.fresh ())))))
| A.EErrorOnEmpty e1 ->
let tau' = unionfind (TDefault tau) in
let e1' = typecheck_expr_top_down ~leave_unresolved ctx env tau' e1 in
let e1' = typecheck_expr_top_down ctx env tau' e1 in
Expr.eerroronempty e1' context_mark
| A.EArray es ->
let cell_type = unionfind (TAny (Any.fresh ())) in
let mark = mark_with_tau_and_unify (unionfind (TArray cell_type)) in
let es' =
List.map (typecheck_expr_top_down ~leave_unresolved ctx env cell_type) es
in
let es' = List.map (typecheck_expr_top_down ctx env cell_type) es in
Expr.earray es' mark
| A.ECustom { obj; targs; tret } ->
let mark =
@ -920,42 +919,36 @@ let wrap_expr ctx f e =
(** {1 API} *)
let get_ty_mark ~leave_unresolved (A.Custom { A.custom = uf; pos }) =
A.Typed { ty = typ_to_ast ~leave_unresolved uf; pos }
let get_ty_mark ~flags (A.Custom { A.custom = uf; pos }) =
A.Typed { ty = typ_to_ast ~flags uf; pos }
let expr_raw
(type a)
~(leave_unresolved : resolving_strategy)
(ctx : A.decl_ctx)
?(env = Env.empty ctx)
?(typ : A.typ option)
(e : (a, 'm) A.gexpr) : (a, unionfind_typ A.custom) A.gexpr =
let fty =
match typ with
| None -> typecheck_expr_bottom_up ~leave_unresolved ctx env
| Some typ ->
typecheck_expr_top_down ~leave_unresolved ctx env (ast_to_typ typ)
| None -> typecheck_expr_bottom_up ctx env
| Some typ -> typecheck_expr_top_down ctx env (ast_to_typ typ)
in
wrap_expr ctx fty e
let check_expr ~leave_unresolved ctx ?env ?typ e =
let check_expr ctx ?env ?typ e =
Expr.map_marks
~f:(fun (Custom { pos; _ }) -> A.Untyped { pos })
(expr_raw ctx ~leave_unresolved ?env ?typ e)
(expr_raw ctx ?env ?typ e)
(* Infer the type of an expression *)
let expr ~leave_unresolved ctx ?env ?typ e =
Expr.map_marks
~f:(get_ty_mark ~leave_unresolved)
(expr_raw ~leave_unresolved ctx ?env ?typ e)
let expr ctx ?(env = Env.empty ctx) ?typ e =
Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) (expr_raw ctx ~env ?typ e)
let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
let rec scope_body_expr ctx env ty_out body_expr =
match body_expr with
| A.Result e ->
let e' =
wrap_expr ctx (typecheck_expr_top_down ~leave_unresolved ctx env ty_out) e
in
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in
let e' = wrap_expr ctx (typecheck_expr_top_down ctx env ty_out) e in
let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
Bindlib.box_apply (fun e -> A.Result e) (Expr.Box.lift e')
| A.ScopeLet
{
@ -966,9 +959,7 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
scope_let_pos;
} ->
let ty_e = ast_to_typ scope_let_typ in
let e =
wrap_expr ctx (typecheck_expr_bottom_up ~leave_unresolved ctx env) e0
in
let e = wrap_expr ctx (typecheck_expr_bottom_up ctx env) e0 in
wrap ctx (fun t -> unify ctx e0 (ty e) t) ty_e;
(* We could use [typecheck_expr_top_down] rather than this manual
unification, but we get better messages with this order of the [unify]
@ -976,7 +967,7 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
inferred. *)
let var, next = Bindlib.unbind scope_let_next in
let env = Env.add var ty_e env in
let next = scope_body_expr ~leave_unresolved ctx env ty_out next in
let next = scope_body_expr ctx env ty_out next in
let scope_let_next = Bindlib.bind_var (Var.translate var) next in
Bindlib.box_apply2
(fun scope_let_expr scope_let_next ->
@ -985,16 +976,16 @@ let rec scope_body_expr ~leave_unresolved ctx env ty_out body_expr =
scope_let_kind;
scope_let_typ =
(match Mark.remove scope_let_typ with
| TAny -> typ_to_ast ~leave_unresolved (ty e)
| TAny -> typ_to_ast ~flags:env.flags (ty e)
| _ -> scope_let_typ);
scope_let_expr;
scope_let_next;
scope_let_pos;
})
(Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e))
(Expr.Box.lift (Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e))
scope_let_next
let scope_body ~leave_unresolved ctx env body =
let scope_body ctx env body =
let get_pos struct_name = Mark.get (A.StructName.get_info struct_name) in
let struct_ty struct_name =
UnionFind.make (Mark.add (get_pos struct_name) (TStruct struct_name))
@ -1003,7 +994,7 @@ let scope_body ~leave_unresolved ctx env body =
let ty_out = struct_ty body.A.scope_body_output_struct in
let var, e = Bindlib.unbind body.A.scope_body_expr in
let env = Env.add var ty_in env in
let e' = scope_body_expr ~leave_unresolved ctx env ty_out e in
let e' = scope_body_expr ctx env ty_out e in
( Bindlib.box_apply
(fun scope_body_expr -> { body with scope_body_expr })
(Bindlib.bind_var (Var.translate var) e'),
@ -1012,35 +1003,33 @@ let scope_body ~leave_unresolved ctx env body =
(get_pos body.A.scope_body_output_struct)
(TArrow ([ty_in], ty_out))) )
let rec scopes ~leave_unresolved ctx env = function
let rec scopes ctx env = function
| A.Nil -> Bindlib.box A.Nil, env
| A.Cons (item, next_bind) ->
let var, next = Bindlib.unbind next_bind in
let env, def =
match item with
| A.ScopeDef (name, body) ->
let body_e, ty_scope = scope_body ~leave_unresolved ctx env body in
let body_e, ty_scope = scope_body ctx env body in
( Env.add var ty_scope env,
Bindlib.box_apply (fun body -> A.ScopeDef (name, body)) body_e )
| A.Topdef (name, typ, e) ->
let e' = expr_raw ~leave_unresolved ctx ~env ~typ e in
let e' = expr_raw ctx ~env ~typ e in
let (A.Custom { custom = uf; _ }) = Mark.get e' in
let e' = Expr.map_marks ~f:(get_ty_mark ~leave_unresolved) e' in
let e' = Expr.map_marks ~f:(get_ty_mark ~flags:env.flags) e' in
( Env.add var uf env,
Bindlib.box_apply
(fun e -> A.Topdef (name, Expr.ty e', e))
(Expr.Box.lift e') )
in
let next', env = scopes ~leave_unresolved ctx env next in
let next', env = scopes ctx env next in
let next_bind' = Bindlib.bind_var (Var.translate var) next' in
( Bindlib.box_apply2 (fun item next -> A.Cons (item, next)) def next_bind',
env )
let program ~leave_unresolved prg =
let code_items, new_env =
scopes ~leave_unresolved prg.A.decl_ctx (Env.empty prg.A.decl_ctx)
prg.A.code_items
in
let program ?fail_on_any ?assume_op_types prg =
let env = Env.empty ?fail_on_any ?assume_op_types prg.A.decl_ctx in
let code_items, new_env = scopes prg.A.decl_ctx env prg.A.code_items in
{
A.lang = prg.lang;
A.module_name = prg.A.module_name;
@ -1055,7 +1044,7 @@ let program ~leave_unresolved prg =
(fun f_name (t : A.typ) ->
match Mark.remove t with
| TAny ->
typ_to_ast ~leave_unresolved
typ_to_ast ~flags:env.flags
(A.StructField.Map.find f_name
(A.StructName.Map.find s_name new_env.structs))
| _ -> t)
@ -1068,7 +1057,7 @@ let program ~leave_unresolved prg =
(fun cons_name (t : A.typ) ->
match Mark.remove t with
| TAny ->
typ_to_ast ~leave_unresolved
typ_to_ast ~flags:env.flags
(A.EnumConstructor.Map.find cons_name
(A.EnumName.Map.find e_name new_env.enums))
| _ -> t)

View File

@ -22,7 +22,17 @@ open Definitions
module Env : sig
type 'e t
val empty : decl_ctx -> 'e t
val empty : ?fail_on_any:bool -> ?assume_op_types:bool -> decl_ctx -> 'e t
(** The [~fail_on_any] labeled parameter controls the behavior of the typer in
the case where polymorphic expressions are still found after typing: if
[false], it allows them (giving them [TAny] and losing typing
information); if set to [true] (the default), it aborts.
The [~assume_op_types] flag (default false) ignores the expected built-in
types of polymorphic operators, and will assume correct the type
information included in [EAppOp] nodes. This is useful after
monomorphisation, which changes the expected types for these operators. *)
val add_var : 'e Var.t -> typ -> 'e t -> 'e t
val add_toplevel_var : TopdefName.t -> typ -> 'e t -> 'e t
val add_scope_var : ScopeVar.t -> typ -> 'e t -> 'e t
@ -40,15 +50,7 @@ module Env : sig
(** For debug purposes *)
end
(** In the following functions, the [~leave_unresolved] labeled parameter
controls the behavior of the typer in the case where polymorphic expressions
are still found after typing: if set to [LeaveAny], it allows them (giving
them [TAny] and losing typing information); if set to [ErrorOnAny], it
aborts. *)
type resolving_strategy = LeaveAny | ErrorOnAny
val expr :
leave_unresolved:resolving_strategy ->
decl_ctx ->
?env:'e Env.t ->
?typ:typ ->
@ -75,11 +77,10 @@ val expr :
application, taking de-tuplification into account.
- [TAny] appearing within nodes are refined to more precise types, e.g. on
`EAbs` nodes (but be careful with this, it may only work for specific
structures of generated code ; [~leave_unresolved:false] checks that it
didn't cause problems) *)
structures of generated code ; having [~fail_on_any:true] set in the
environment (this is the default) checks that it didn't cause problems) *)
val check_expr :
leave_unresolved:resolving_strategy ->
decl_ctx ->
?env:'e Env.t ->
?typ:typ ->
@ -91,7 +92,8 @@ val check_expr :
information, e.g. any [TAny] appearing in the AST is replaced) *)
val program :
leave_unresolved:resolving_strategy ->
?fail_on_any:bool ->
?assume_op_types:bool ->
('a, 'm) gexpr program ->
('a, typed) gexpr program
(** Typing on whole programs (as defined in Shared_ast.program, i.e. for the

View File

@ -29,5 +29,5 @@ scope Baz:
```catala-test { id = "c" }
$ catala c -o -
$ catala c
```

View File

@ -53,8 +53,8 @@ let scope S (S_in: S_in {x_in: list of integer}): S {y: integer} =
[
handle_default_opt
[]
(λ (_: unit) → true)
(λ (_: unit) →
(λ () → true)
(λ () →
ESome
(let weights : list of (integer * integer) =
map (λ (potential_max: integer) →
@ -72,10 +72,10 @@ let scope S (S_in: S_in {x_in: list of integer}): S {y: integer} =
potential_max1)
weights).0)
]
(λ (_: unit) → false)
(λ (_: unit) → ENone ()))
(λ () → false)
(λ () → ENone ()))
with
| ENone _ → raise NoValueProvided
| ENone → raise NoValueProvided
| ESome arg → arg
in
return { S y = y; }

View File

@ -122,12 +122,9 @@ let scope Foo
in
let set b : bool =
match
(handle_default_opt
[b.0 b.1 ()]
(λ (_: unit) → true)
(λ (_: unit) → ESome true))
(handle_default_opt [b.0 b.1 ()] (λ () → true) (λ () → ESome true))
with
| ENone _ → raise NoValueProvided
| ENone → raise NoValueProvided
| ESome arg → arg
in
let set r :

View File

@ -31,7 +31,7 @@ $ catala Typecheck --check-invariants
```catala-test-inline
$ catala Dcalc -s B
let scope B (B_in: B_in): B =
let sub_set a.a : unit → ⟨integer⟩ = λ (_: unit) → ∅ in
let sub_set a.a : unit → ⟨integer⟩ = λ () → ∅ in
let sub_set a.b : integer = error_empty ⟨ ⟨true ⊢ ⟨2⟩⟩ | false ⊢ ∅ ⟩ in
let call result : A {c: integer} = A { A_in a_in = a.a; b_in = a.b; } in
let sub_get a.c : integer = result.c in

View File

@ -41,5 +41,5 @@ $ catala typecheck --disable-warnings
```
```catala-test { id="ml" }
$ catala ocaml --disable-warnings -o -
$ catala ocaml --disable-warnings
```

View File

@ -22,12 +22,9 @@ let scope Foo (Foo_in: Foo_in): Foo {bar: integer} =
let set bar : integer =
try
handle_default
[
λ (_: unit) →
handle_default [] (λ (_1: unit) → true) (λ (_1: unit) → 0)
]
(λ (_: unit) → false)
(λ (_: unit) → raise EmptyError)
[λ () → handle_default [] (λ () → true) (λ () → 0)]
(λ () → false)
(λ () → raise EmptyError)
with EmptyError -> raise NoValueProvided
in
return { Foo bar = bar; }