python3Packages.cupy: fix (use older cutensor)

This commit is contained in:
Someone Serge 2023-11-27 17:36:57 +00:00
parent 0dc161b2f8
commit ee108108fc
No known key found for this signature in database
GPG Key ID: 7B0E3B1390D61DA4
2 changed files with 39 additions and 10 deletions

View File

@ -11,11 +11,34 @@
, cudaPackages
, addOpenGLRunpath
, pythonOlder
, symlinkJoin
}:
let
inherit (cudaPackages) cudatoolkit cudnn cutensor nccl;
in buildPythonPackage rec {
inherit (cudaPackages) cudnn cutensor nccl;
cudatoolkit-joined = symlinkJoin {
name = "cudatoolkit-joined-${cudaPackages.cudaVersion}";
paths = with cudaPackages; [
cuda_cccl # <nv/target>
cuda_cccl.dev
cuda_cudart
cuda_nvcc.dev # <crt/host_defines.h>
cuda_nvprof
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
libcublas
libcufft
libcurand
libcusolver
libcusparse
# Missing:
# cusparselt
];
};
in
buildPythonPackage rec {
pname = "cupy";
version = "12.2.0";
@ -32,27 +55,32 @@ in buildPythonPackage rec {
# very short builds and a few extremely long ones, so setting both ends up
# working nicely in practice.
preConfigure = ''
export CUDA_PATH=${cudatoolkit}
export CUPY_NUM_BUILD_JOBS="$NIX_BUILD_CORES"
export CUPY_NUM_NVCC_THREADS="$NIX_BUILD_CORES"
'';
nativeBuildInputs = [
setuptools
wheel
addOpenGLRunpath
cython
cudaPackages.cuda_nvcc
];
LDFLAGS = "-L${cudatoolkit}/lib/stubs";
propagatedBuildInputs = [
cudatoolkit
buildInputs = [
cudatoolkit-joined
cudnn
cutensor
nccl
];
NVCC = "${lib.getExe cudaPackages.cuda_nvcc}"; # FIXME: splicing/buildPackages
CUDA_PATH = "${cudatoolkit-joined}";
LDFLAGS = "-L${cudaPackages.cuda_cudart}/lib/stubs";
propagatedBuildInputs = [
fastrlock
numpy
setuptools
wheel
];
nativeCheckInputs = [

View File

@ -2467,7 +2467,8 @@ self: super: with self; {
cufflinks = callPackage ../development/python-modules/cufflinks { };
cupy = callPackage ../development/python-modules/cupy { };
# cupy 12.2.0 possibly incompatible with cutensor 2.0 that comes with cudaPackages_12
cupy = callPackage ../development/python-modules/cupy { cudaPackages = pkgs.cudaPackages_11; };
curio = callPackage ../development/python-modules/curio { };