diff --git a/pkgs/development/python-modules/dm-haiku/default.nix b/pkgs/development/python-modules/dm-haiku/default.nix index 5468776c72d1..14ceebf0a19b 100644 --- a/pkgs/development/python-modules/dm-haiku/default.nix +++ b/pkgs/development/python-modules/dm-haiku/default.nix @@ -4,6 +4,7 @@ , dill , dm-tree , fetchFromGitHub +, jaxlib , jmp , lib , pytestCheckHook @@ -31,6 +32,7 @@ buildPythonPackage rec { chex cloudpickle dm-tree + jaxlib pytestCheckHook tensorflow ]; diff --git a/pkgs/development/python-modules/elegy/default.nix b/pkgs/development/python-modules/elegy/default.nix index ec968dce8ec3..5b0cb293c0a3 100644 --- a/pkgs/development/python-modules/elegy/default.nix +++ b/pkgs/development/python-modules/elegy/default.nix @@ -4,6 +4,7 @@ , deepmerge , dm-haiku , fetchFromGitHub +, jaxlib , lib , poetry , pytestCheckHook @@ -35,6 +36,8 @@ buildPythonPackage rec { poetry ]; + buildInputs = [ jaxlib ]; + propagatedBuildInputs = [ cloudpickle deepdish diff --git a/pkgs/development/python-modules/flax/default.nix b/pkgs/development/python-modules/flax/default.nix index b8479c0f73ca..b5a22858b57a 100644 --- a/pkgs/development/python-modules/flax/default.nix +++ b/pkgs/development/python-modules/flax/default.nix @@ -1,5 +1,6 @@ { buildPythonPackage , fetchFromGitHub +, jaxlib , keras , lib , matplotlib @@ -21,6 +22,8 @@ buildPythonPackage rec { sha256 = "0zvq0vl88hiwmss49bnm7gdmndr1dfza2bcs1fj88a9r7w9dmlsr"; }; + buildInputs = [ jaxlib ]; + propagatedBuildInputs = [ matplotlib msgpack diff --git a/pkgs/development/python-modules/jmp/default.nix b/pkgs/development/python-modules/jmp/default.nix index dc096b93ae46..09c41a7ededd 100644 --- a/pkgs/development/python-modules/jmp/default.nix +++ b/pkgs/development/python-modules/jmp/default.nix @@ -19,10 +19,9 @@ buildPythonPackage rec { sha256 = "0hh4cmp93wjyidj48gh07vhx2kjvpwd23xvy79bsjn5qaaf6q4cm"; }; - # Wheel requires only `numpy`, but the import needs both `jax` and `jaxlib`. + # Wheel requires only `numpy`, but the import needs `jax`. propagatedBuildInputs = [ jax - jaxlib ]; pythonImportsCheck = [ @@ -30,6 +29,7 @@ buildPythonPackage rec { ]; checkInputs = [ + jaxlib pytestCheckHook ]; diff --git a/pkgs/development/python-modules/optax/default.nix b/pkgs/development/python-modules/optax/default.nix index bf0383fa1530..6a3b6a9d3e67 100644 --- a/pkgs/development/python-modules/optax/default.nix +++ b/pkgs/development/python-modules/optax/default.nix @@ -22,10 +22,11 @@ buildPythonPackage rec { sha256 = "1q8cxc42a5xais2ll1l238cnn3l7w28savhgiz0lg01ilz2ysbli"; }; + buildInputs = [ jaxlib ]; + propagatedBuildInputs = [ absl-py chex - jaxlib numpy ]; diff --git a/pkgs/development/python-modules/treeo/default.nix b/pkgs/development/python-modules/treeo/default.nix index 3629b47e8a29..4eac9ddeae12 100644 --- a/pkgs/development/python-modules/treeo/default.nix +++ b/pkgs/development/python-modules/treeo/default.nix @@ -22,12 +22,12 @@ buildPythonPackage rec { poetry-core ]; - # These deps are not needed for the wheel, but required during the import. + # jax is not declared in the dependencies, but is necessary. propagatedBuildInputs = [ jax - jaxlib ]; + checkInputs = [ jaxlib ]; pythonImportsCheck = [ "treeo" ]; diff --git a/pkgs/development/python-modules/treex/default.nix b/pkgs/development/python-modules/treex/default.nix index 1f4a55416d56..0b5ad0c89839 100644 --- a/pkgs/development/python-modules/treex/default.nix +++ b/pkgs/development/python-modules/treex/default.nix @@ -5,6 +5,7 @@ , fetchFromGitHub , flax , hypothesis +, jaxlib , keras , lib , poetry-core @@ -38,6 +39,8 @@ buildPythonPackage rec { poetry-core ]; + buildInputs = [ jaxlib ]; + propagatedBuildInputs = [ einops flax