diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix index 35eb79d8b2d3..0de0015ab1ec 100644 --- a/pkgs/development/python-modules/pytorch/default.nix +++ b/pkgs/development/python-modules/pytorch/default.nix @@ -301,6 +301,11 @@ in buildPythonPackage rec { # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder. requiredSystemFeatures = [ "big-parallel" ]; + passthru = { + inherit cudaSupport; + cudaArchList = final_cudaArchList; + }; + meta = with lib; { description = "Open source, prototype-to-production deep learning platform"; homepage = "https://pytorch.org/"; diff --git a/pkgs/development/python-modules/torchvision/default.nix b/pkgs/development/python-modules/torchvision/default.nix index a42c517ede96..fc9905881cb6 100644 --- a/pkgs/development/python-modules/torchvision/default.nix +++ b/pkgs/development/python-modules/torchvision/default.nix @@ -1,4 +1,5 @@ { lib +, symlinkJoin , buildPythonPackage , fetchFromGitHub , ninja @@ -10,9 +11,18 @@ , pillow , pytorch , pytest +, cudatoolkit +, cudnn +, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch }: -buildPythonPackage rec { +let + cudatoolkit_joined = symlinkJoin { + name = "${cudatoolkit.name}-unsplit"; + paths = [ cudatoolkit.out cudatoolkit.lib ]; + }; + cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList; +in buildPythonPackage rec { pname = "torchvision"; version = "0.10.0"; @@ -23,15 +33,22 @@ buildPythonPackage rec { sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc"; }; - nativeBuildInputs = [ libpng ninja which ]; + nativeBuildInputs = [ libpng ninja which ] + ++ lib.optionals cudaSupport [ cudatoolkit_joined ]; TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/"; TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/"; - buildInputs = [ libjpeg_turbo libpng ]; + buildInputs = [ libjpeg_turbo libpng ] + ++ lib.optionals cudaSupport [ cudnn ]; propagatedBuildInputs = [ numpy pillow pytorch scipy ]; + preBuild = lib.optionalString cudaSupport '' + export TORCH_CUDA_ARCH_LIST="${cudaArchStr}" + export FORCE_CUDA=1 + ''; + # tries to download many datasets for tests doCheck = false; @@ -45,6 +62,7 @@ buildPythonPackage rec { description = "PyTorch vision library"; homepage = "https://pytorch.org/"; license = licenses.bsd3; + platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin; maintainers = with maintainers; [ ericsagnes ]; }; }