python311Packages.objax: patch deprecated device_buffers code

This commit is contained in:
Gaetan Lepage 2024-05-10 11:32:23 +02:00
parent 24e02939e9
commit 0ab7fd77be
2 changed files with 21 additions and 3 deletions

View File

@ -1,7 +1,6 @@
{ lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, jax
, jaxlib
, keras
@ -30,7 +29,12 @@ buildPythonPackage rec {
hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
};
nativeBuildInputs = [
patches = [
# Issue reported upstream: https://github.com/google/objax/issues/270
./replace-deprecated-device_buffers.patch
];
build-system = [
setuptools
];
@ -40,7 +44,7 @@ buildPythonPackage rec {
jaxlib
];
propagatedBuildInputs = [
dependencies = [
jax
numpy
parameterized

View File

@ -0,0 +1,14 @@
diff --git a/objax/util/util.py b/objax/util/util.py
index c31a356..344cf9a 100644
--- a/objax/util/util.py
+++ b/objax/util/util.py
@@ -117,7 +117,8 @@ def get_local_devices():
if _local_devices is None:
x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
sharded_x = map_to_device(x)
- _local_devices = [b.device() for b in sharded_x.device_buffers]
+ device_buffers = [buf.data for buf in sharded_x.addressable_shards]
+ _local_devices = [list(b.devices())[0] for b in device_buffers]
return _local_devices