call torch_gc before/after each gpu gradio operation

This commit is contained in:
AUTOMATIC 2022-09-29 11:32:12 +03:00
parent c1c27dad3b
commit 2f2d356e4c

View File

@ -1,6 +1,7 @@
import os import os
import threading import threading
from modules import devices
from modules.paths import script_path from modules.paths import script_path
import signal import signal
@ -47,6 +48,8 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func): def wrap_gradio_gpu_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
devices.torch_gc()
shared.state.sampling_step = 0 shared.state.sampling_step = 0
shared.state.job_count = -1 shared.state.job_count = -1
shared.state.job_no = 0 shared.state.job_no = 0
@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func):
shared.state.job = "" shared.state.job = ""
shared.state.job_count = 0 shared.state.job_count = 0
devices.torch_gc()
return res return res
return modules.ui.wrap_gradio_call(f) return modules.ui.wrap_gradio_call(f)