mirror of
https://github.com/ilyakooo0/nixpkgs.git
synced 2024-10-12 23:48:25 +03:00
python311Packages.objax: patch deprecated device_buffers code
This commit is contained in:
parent
24e02939e9
commit
0ab7fd77be
@ -1,7 +1,6 @@
|
||||
{ lib
|
||||
, buildPythonPackage
|
||||
, fetchFromGitHub
|
||||
, fetchpatch
|
||||
, jax
|
||||
, jaxlib
|
||||
, keras
|
||||
@ -30,7 +29,12 @@ buildPythonPackage rec {
|
||||
hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
patches = [
|
||||
# Issue reported upstream: https://github.com/google/objax/issues/270
|
||||
./replace-deprecated-device_buffers.patch
|
||||
];
|
||||
|
||||
build-system = [
|
||||
setuptools
|
||||
];
|
||||
|
||||
@ -40,7 +44,7 @@ buildPythonPackage rec {
|
||||
jaxlib
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
dependencies = [
|
||||
jax
|
||||
numpy
|
||||
parameterized
|
||||
|
@ -0,0 +1,14 @@
|
||||
diff --git a/objax/util/util.py b/objax/util/util.py
|
||||
index c31a356..344cf9a 100644
|
||||
--- a/objax/util/util.py
|
||||
+++ b/objax/util/util.py
|
||||
@@ -117,7 +117,8 @@ def get_local_devices():
|
||||
if _local_devices is None:
|
||||
x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
|
||||
sharded_x = map_to_device(x)
|
||||
- _local_devices = [b.device() for b in sharded_x.device_buffers]
|
||||
+ device_buffers = [buf.data for buf in sharded_x.addressable_shards]
|
||||
+ _local_devices = [list(b.devices())[0] for b in device_buffers]
|
||||
return _local_devices
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user