mirror of
https://github.com/ilyakooo0/nixpkgs.git
synced 2024-09-30 01:17:28 +03:00
python3Packages.torch: enable cuDNN & NCCL only if available
This commit is contained in:
parent
5a54140dff
commit
5bf016e1e9
@ -56,10 +56,7 @@
|
||||
|
||||
let
|
||||
inherit (lib) attrsets lists strings trivial;
|
||||
inherit (cudaPackages) cudaFlags cudnn;
|
||||
|
||||
# Some packages are not available on all platforms
|
||||
nccl = cudaPackages.nccl or null;
|
||||
inherit (cudaPackages) cudaFlags cudnn nccl;
|
||||
|
||||
setBool = v: if v then "1" else "0";
|
||||
|
||||
@ -212,10 +209,11 @@ in buildPythonPackage rec {
|
||||
# For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
|
||||
preConfigure = lib.optionalString cudaSupport ''
|
||||
export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
|
||||
export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
|
||||
export CUDNN_LIB_DIR=${cudnn.lib}/lib
|
||||
export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
|
||||
export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
|
||||
'' + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
|
||||
export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
|
||||
export CUDNN_LIB_DIR=${cudnn.lib}/lib
|
||||
'' + lib.optionalString rocmSupport ''
|
||||
export ROCM_PATH=${rocmtoolkit_joined}
|
||||
export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
|
||||
@ -273,7 +271,7 @@ in buildPythonPackage rec {
|
||||
PYTORCH_BUILD_VERSION = version;
|
||||
PYTORCH_BUILD_NUMBER = 0;
|
||||
|
||||
USE_NCCL = setBool (nccl != null);
|
||||
USE_NCCL = setBool (cudaPackages ? nccl);
|
||||
USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL
|
||||
USE_STATIC_NCCL = setBool useSystemNccl;
|
||||
|
||||
@ -348,8 +346,6 @@ in buildPythonPackage rec {
|
||||
cuda_nvrtc.lib
|
||||
cuda_nvtx.dev
|
||||
cuda_nvtx.lib # -llibNVToolsExt
|
||||
cudnn.dev
|
||||
cudnn.lib
|
||||
libcublas.dev
|
||||
libcublas.lib
|
||||
libcufft.dev
|
||||
@ -360,7 +356,10 @@ in buildPythonPackage rec {
|
||||
libcusolver.lib
|
||||
libcusparse.dev
|
||||
libcusparse.lib
|
||||
] ++ lists.optionals (nccl != null) [
|
||||
] ++ lists.optionals (cudaPackages ? cudnn) [
|
||||
cudnn.dev
|
||||
cudnn.lib
|
||||
] ++ lists.optionals (useSystemNccl && cudaPackages ? nccl) [
|
||||
# Some platforms do not support NCCL (i.e., Jetson)
|
||||
nccl.dev # Provides nccl.h AND a static copy of NCCL!
|
||||
] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
|
||||
|
Loading…
Reference in New Issue
Block a user