Merge pull request #266239 from GaetanLepage/flax

python311Packages.flax: dependencies and tests check up
This commit is contained in:
Nick Cao 2023-11-08 08:31:42 -05:00 committed by GitHub
commit e0f8c84d82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,14 +4,17 @@
, jaxlib
, pythonRelaxDepsHook
, setuptools-scm
, cloudpickle
, jax
, matplotlib
, msgpack
, numpy
, optax
, pyyaml
, rich
, tensorstore
, typing-extensions
, matplotlib
, cloudpickle
, einops
, keras
, pytest-xdist
, pytestCheckHook
@ -37,24 +40,27 @@ buildPythonPackage rec {
];
propagatedBuildInputs = [
cloudpickle
jax
matplotlib
msgpack
numpy
optax
pyyaml
rich
tensorstore
typing-extensions
];
# See https://github.com/google/flax/pull/2882.
pythonRemoveDeps = [ "orbax" ];
passthru.optional-dependencies = {
all = [ matplotlib ];
};
pythonImportsCheck = [
"flax"
];
nativeCheckInputs = [
cloudpickle
einops
keras
pytest-xdist
pytestCheckHook
@ -85,22 +91,6 @@ buildPythonPackage rec {
"tests/checkpoints_test.py"
];
disabledTests = [
# See https://github.com/google/flax/issues/2554.
"test_async_save_checkpoints"
"test_jax_array0"
"test_jax_array1"
"test_keep0"
"test_keep1"
"test_optimized_lstm_cell_matches_regular"
"test_overwrite_checkpoints"
"test_save_restore_checkpoints_target_empty"
"test_save_restore_checkpoints_target_none"
"test_save_restore_checkpoints_target_singular"
"test_save_restore_checkpoints_w_float_steps"
"test_save_restore_checkpoints"
];
meta = with lib; {
description = "Neural network library for JAX";
homepage = "https://github.com/google/flax";