python3Packages.torch: enable cuDNN & NCCL only if available

This commit is contained in:
Connor Baker 2023-12-01 17:00:22 +00:00
parent 5a54140dff
commit 5bf016e1e9

View File

@ -56,10 +56,7 @@
let
inherit (lib) attrsets lists strings trivial;
inherit (cudaPackages) cudaFlags cudnn;
# Some packages are not available on all platforms
nccl = cudaPackages.nccl or null;
inherit (cudaPackages) cudaFlags cudnn nccl;
setBool = v: if v then "1" else "0";
@ -212,10 +209,11 @@ in buildPythonPackage rec {
# For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
preConfigure = lib.optionalString cudaSupport ''
export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
export CUDNN_LIB_DIR=${cudnn.lib}/lib
export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
'' + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
export CUDNN_LIB_DIR=${cudnn.lib}/lib
'' + lib.optionalString rocmSupport ''
export ROCM_PATH=${rocmtoolkit_joined}
export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
@ -273,7 +271,7 @@ in buildPythonPackage rec {
PYTORCH_BUILD_VERSION = version;
PYTORCH_BUILD_NUMBER = 0;
USE_NCCL = setBool (nccl != null);
USE_NCCL = setBool (cudaPackages ? nccl);
USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL
USE_STATIC_NCCL = setBool useSystemNccl;
@ -348,8 +346,6 @@ in buildPythonPackage rec {
cuda_nvrtc.lib
cuda_nvtx.dev
cuda_nvtx.lib # -llibNVToolsExt
cudnn.dev
cudnn.lib
libcublas.dev
libcublas.lib
libcufft.dev
@ -360,7 +356,10 @@ in buildPythonPackage rec {
libcusolver.lib
libcusparse.dev
libcusparse.lib
] ++ lists.optionals (nccl != null) [
] ++ lists.optionals (cudaPackages ? cudnn) [
cudnn.dev
cudnn.lib
] ++ lists.optionals (useSystemNccl && cudaPackages ? nccl) [
# Some platforms do not support NCCL (i.e., Jetson)
nccl.dev # Provides nccl.h AND a static copy of NCCL!
] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [