Added option to use cudnn as backend for pytorch, this should help fixing an issue with nvidia 16xx cards getting a black or green square instead of a proper image. (#1699)

This commit is contained in:
Alejandro Gil 2022-12-03 04:58:57 -08:00 committed by GitHub
commit 53086b888d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 4 deletions

View File

@ -59,6 +59,7 @@ general:
no_half: False
use_float16: False
precision: "autocast"
use_cudnn: False
optimized: False
optimized_turbo: False
optimized_config: "optimizedSD/v1-inference.yaml"

View File

@ -23,13 +23,13 @@ channels:
dependencies:
- conda-forge::nodejs=18.11.0
- yarn=1.22.19
- cudatoolkit=11.3
- cudatoolkit=11.7
- git
- numpy=1.22.3
- numpy=1.23.3
- pip=20.3
- python=3.8.5
- pytorch=1.11.0
- pytorch=1.13.0
- scikit-image=0.19.2
- torchvision=0.12.0
- torchvision=0.14.0
- pip:
- -r requirements.txt

View File

@ -95,6 +95,12 @@ except ImportError as e:
# remove all the annoying python warnings.
shutup.please()
# the following lines should help fixing an issue with nvidia 16xx cards.
if "defaults" in st.session_state:
if st.session_state["defaults"].general.use_cudnn:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
@ -1613,6 +1619,10 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
#
@retry(tries=5)
def generation_callback(img, i=0):
# try to do garbage collection before decoding the image
torch_gc()
if "update_preview_frequency" not in st.session_state:
raise StopException