mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
support for generating images on video cards with 4GB
This commit is contained in:
parent
7a7a3a6b19
commit
9c9f048b5e
90
webui.py
90
webui.py
@ -2,6 +2,8 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -51,6 +53,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not
|
|||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
|
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
|
parser.add_argument("--lowvram", action='store_true', help="enamble optimizations for low vram")
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
|
||||||
@ -185,11 +188,80 @@ def load_model_from_config(config, ckpt, verbose=False):
|
|||||||
print("unexpected keys:")
|
print("unexpected keys:")
|
||||||
print(u)
|
print(u)
|
||||||
|
|
||||||
model.cuda()
|
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
module_in_gpu = None
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_low_vram(sd_model):
|
||||||
|
parents = {}
|
||||||
|
|
||||||
|
def send_me_to_gpu(module, _):
|
||||||
|
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
|
||||||
|
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
|
||||||
|
be in CPU
|
||||||
|
"""
|
||||||
|
global module_in_gpu
|
||||||
|
|
||||||
|
module = parents.get(module, module)
|
||||||
|
|
||||||
|
if module_in_gpu == module:
|
||||||
|
return
|
||||||
|
|
||||||
|
if module_in_gpu is not None:
|
||||||
|
print('removing from gpu:', type(module_in_gpu))
|
||||||
|
module_in_gpu.to(cpu)
|
||||||
|
|
||||||
|
print('adding to gpu:', type(module))
|
||||||
|
module.to(gpu)
|
||||||
|
|
||||||
|
print('added to gpu:', type(module))
|
||||||
|
module_in_gpu = module
|
||||||
|
|
||||||
|
# see below for register_forward_pre_hook;
|
||||||
|
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
||||||
|
# useless here, and we just replace those methods
|
||||||
|
def first_stage_model_encode_wrap(self, encoder, x):
|
||||||
|
send_me_to_gpu(self, None)
|
||||||
|
return encoder(x)
|
||||||
|
|
||||||
|
def first_stage_model_decode_wrap(self, decoder, z):
|
||||||
|
send_me_to_gpu(self, None)
|
||||||
|
return decoder(z)
|
||||||
|
|
||||||
|
# remove three big modules, cond, first_stage, and unet from the model and then
|
||||||
|
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||||
|
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
||||||
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
||||||
|
sd_model.to(device)
|
||||||
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
||||||
|
|
||||||
|
# register hooks for those the first two models
|
||||||
|
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
||||||
|
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
||||||
|
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||||
|
|
||||||
|
# the third remaining model is still too big for 4GB, so we also do the same for its submodules
|
||||||
|
# so that only one of them is in GPU at a time
|
||||||
|
diff_model = sd_model.model.diffusion_model
|
||||||
|
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
||||||
|
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
||||||
|
sd_model.model.to(device)
|
||||||
|
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
||||||
|
|
||||||
|
# install hooks for bits of third model
|
||||||
|
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
for block in diff_model.input_blocks:
|
||||||
|
block.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
for block in diff_model.output_blocks:
|
||||||
|
block.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
|
||||||
|
|
||||||
def create_random_tensors(shape, seeds):
|
def create_random_tensors(shape, seeds):
|
||||||
xs = []
|
xs = []
|
||||||
for seed in seeds:
|
for seed in seeds:
|
||||||
@ -838,7 +910,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
|
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
|
||||||
|
|
||||||
output_images = []
|
output_images = []
|
||||||
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
ema_scope = (nullcontext if cmd_opts.lowvram else model.ema_scope)
|
||||||
|
with torch.no_grad(), autocast("cuda"), ema_scope():
|
||||||
p.init()
|
p.init()
|
||||||
|
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
@ -1327,8 +1400,17 @@ interfaces = [
|
|||||||
sd_config = OmegaConf.load(cmd_opts.config)
|
sd_config = OmegaConf.load(cmd_opts.config)
|
||||||
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
||||||
|
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
|
gpu = torch.device("cuda")
|
||||||
|
device = gpu if torch.cuda.is_available() else cpu
|
||||||
|
|
||||||
|
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
|
||||||
|
|
||||||
|
if not cmd_opts.lowvram:
|
||||||
|
sd_model = sd_model.to(device)
|
||||||
|
|
||||||
|
else:
|
||||||
|
setup_for_low_vram(sd_model)
|
||||||
|
|
||||||
model_hijack = StableDiffusionModelHijack()
|
model_hijack = StableDiffusionModelHijack()
|
||||||
model_hijack.hijack(sd_model)
|
model_hijack.hijack(sd_model)
|
||||||
|
Loading…
Reference in New Issue
Block a user