fix: [streamlit] optimization mode

Thomas Mello 2022-09-14 18:22:24 +03:00
parent eb85dc4d63
commit ede81bdc5c
4 changed files with 85 additions and 69 deletions

View File

@@ -31,6 +31,7 @@ general:
     precision: "autocast"
     optimized: False
     optimized_turbo: False
+    optimized_config: "optimizedSD/v1-inference.yaml"
     update_preview: True
     update_preview_frequency: 5
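
Note: the new `optimized_config` key makes the optimizedSD inference YAML configurable instead of hard-coded. A minimal sketch of how an OmegaConf-backed defaults object exposes it (the inline dict is illustrative, standing in for the parsed webui_streamlit.yaml):

    from omegaconf import OmegaConf

    # Stand-in for the parsed webui_streamlit.yaml defaults.
    defaults = OmegaConf.create({
        "general": {
            "optimized": True,
            "optimized_turbo": False,
            "optimized_config": "optimizedSD/v1-inference.yaml",
        }
    })

    if defaults.general.optimized:
        # load_sd_model() below reads the split-UNet config from this path.
        print(f"Optimized UNet config: {defaults.general.optimized_config}")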

View File

@@ -163,15 +163,15 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
         mask = torch.from_numpy(mask).to(st.session_state["device"])

     if st.session_state['defaults'].general.optimized:
-        modelFS.to(st.session_state["device"])
+        st.session_state.modelFS.to(st.session_state["device"])

     init_image = 2. * image - 1.
     init_image = init_image.to(st.session_state["device"])
-    init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image))  # move to latent space
+    init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image))  # move to latent space

     if st.session_state['defaults'].general.optimized:
         mem = torch.cuda.memory_allocated()/1e6
-        modelFS.to("cpu")
+        st.session_state.modelFS.to("cpu")
         while(torch.cuda.memory_allocated()/1e6 >= mem):
             time.sleep(1)
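
This hunk repeats a pattern used throughout the commit: move a stage onto the GPU for its step, then move it back to CPU and poll `torch.cuda.memory_allocated()` until the allocator has actually released the memory. A self-contained sketch of that offload pattern (the `offload_to_cpu` helper is hypothetical, not part of the commit):

    import time
    import torch

    def offload_to_cpu(stage: torch.nn.Module) -> None:
        # Hypothetical helper mirroring the commit's offload-and-wait pattern.
        if not torch.cuda.is_available():
            stage.to("cpu")
            return
        mem = torch.cuda.memory_allocated() / 1e6  # MB in use before the move
        stage.to("cpu")
        # CUDA frees memory asynchronously, so wait until usage drops.
        while torch.cuda.memory_allocated() / 1e6 >= mem:
            time.sleep(1)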

View File

@@ -163,40 +163,33 @@ def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=Fals
     if "RealESRGAN" in st.session_state:
         del st.session_state["RealESRGAN"]

     if "model" in st.session_state:
         if "model" in st.session_state and st.session_state["custom_model"] == custom_model:
+            # TODO: check if the optimized mode was changed?
             print("Model already loaded")
+            return
         else:
             try:
-                del st.session_state["model"]
+                del st.session_state.model
+                del st.session_state.modelCS
+                del st.session_state.modelFS
             except KeyError:
                 pass
-            config = OmegaConf.load(st.session_state["defaults"].general.default_model_config)
-            if custom_model == st.session_state["defaults"].general.default_model:
-                model = load_model_from_config(config, st.session_state["defaults"].general.default_model_path)
-            else:
-                model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt"))
-            st.session_state["custom_model"] = custom_model
-            st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu")
-            st.session_state["model"] = (model if st.session_state["defaults"].general.no_half else model.half()).to(st.session_state["device"])
-    else:
-        config = OmegaConf.load(st.session_state["defaults"].general.default_model_config)
-        if custom_model == st.session_state["defaults"].general.default_model:
-            model = load_model_from_config(config, st.session_state["defaults"].general.default_model_path)
-        else:
-            model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt"))
-        st.session_state["custom_model"] = custom_model
-        st.session_state["device"] = torch.device(f"cuda:{st.session_state['defaults'].general.gpu}") if torch.cuda.is_available() else torch.device("cpu")
-        st.session_state["model"] = (model if st.session_state['defaults'].general.no_half else model.half()).to(st.session_state["device"])
-    print("Model loaded.")
+
+    # At this point the model either is not loaded yet or has been evicted:
+    # load new model into memory
+    st.session_state.custom_model = custom_model
+
+    config, device, model, modelCS, modelFS = load_sd_model(custom_model)
+
+    st.session_state.device = device
+    st.session_state.model = model
+    st.session_state.modelCS = modelCS
+    st.session_state.modelFS = modelFS
+
+    print("Model loaded.")

 def load_model_from_config(config, ckpt, verbose=False):
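
`load_models()` now treats `st.session_state` as a cache keyed by the model name: reuse the loaded stages on a hit, evict all three on a miss and reload through the new `load_sd_model()` helper. A condensed, self-contained sketch of that flow (a plain dict stands in for Streamlit's session state):

    # Sketch of the caching flow; load_sd_model is the helper added below.
    session_state = {}

    def get_cached_model(custom_model, load_sd_model):
        if "model" in session_state and session_state.get("custom_model") == custom_model:
            print("Model already loaded")
            return session_state["model"]
        # Evict every stage of the previous model before loading a new one.
        for key in ("model", "modelCS", "modelFS"):
            session_state.pop(key, None)
        config, device, model, modelCS, modelFS = load_sd_model(custom_model)
        session_state.update(custom_model=custom_model, device=device,
                             model=model, modelCS=modelCS, modelFS=modelFS)
        print("Model loaded.")
        return model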
@@ -220,6 +213,7 @@ def load_model_from_config(config, ckpt, verbose=False):
     model.eval()
     return model
+
 def load_sd_from_config(ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
@@ -681,18 +675,26 @@ def try_loading_LDSR(model_name: str,checking=False):
 #try_loading_LDSR('model',checking=True)

-def load_SD_model():
-    if st.session_state['defaults'].general.optimized:
-        sd = load_sd_from_config(st.session_state['defaults'].general.default_model_path)
+# Loads Stable Diffusion model by name
+def load_sd_model(model_name: str) -> [any, any, any, any, any]:
+    ckpt_path = st.session_state.defaults.general.default_model_path
+
+    if model_name != st.session_state.defaults.general.default_model:
+        ckpt_path = os.path.join("models", "custom", f"{model_name}.ckpt")
+
+    if st.session_state.defaults.general.optimized:
+        config = OmegaConf.load(st.session_state.defaults.general.optimized_config)
+        sd = load_sd_from_config(ckpt_path)
         li, lo = [], []
         for key, v_ in sd.items():
             sp = key.split('.')
-            if(sp[0]) == 'model':
-                if('input_blocks' in sp):
+            if (sp[0]) == 'model':
+                if 'input_blocks' in sp:
                     li.append(key)
-                elif('middle_block' in sp):
+                elif 'middle_block' in sp:
                     li.append(key)
-                elif('time_embed' in sp):
+                elif 'time_embed' in sp:
                     li.append(key)
                 else:
                     lo.append(key)
@@ -701,14 +703,14 @@ def load_SD_model():
         for key in lo:
             sd['model2.' + key[6:]] = sd.pop(key)

-        config = OmegaConf.load("optimizedSD/v1-inference.yaml")
-        device = torch.device(f"cuda:{opt.gpu}") if torch.cuda.is_available() else torch.device("cpu")
+        device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \
+            if torch.cuda.is_available() else torch.device("cpu")

         model = instantiate_from_config(config.modelUNet)
         _, _ = model.load_state_dict(sd, strict=False)
         model.cuda()
         model.eval()
-        model.turbo = st.session_state['defaults'].general.optimized_turbo
+        model.turbo = st.session_state.defaults.general.optimized_turbo

         modelCS = instantiate_from_config(config.modelCondStage)
         _, _ = modelCS.load_state_dict(sd, strict=False)
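
The remapping above is what enables the optimizedSD split: checkpoint keys under `model.` are re-homed onto two sub-modules, `model1` (input blocks, middle block, time embedding) and `model2` (everything else), so the two halves of the UNet can be loaded and offloaded independently. A standalone sketch of the full remapping (the `model1` loop sits just outside this hunk and is inferred from the visible `model2` loop):

    def split_unet_state_dict(sd: dict) -> dict:
        # Partition 'model.*' keys between the two optimizedSD sub-models.
        li, lo = [], []
        for key in sd:
            sp = key.split('.')
            if sp[0] == 'model':
                if 'input_blocks' in sp or 'middle_block' in sp or 'time_embed' in sp:
                    li.append(key)
                else:
                    lo.append(key)
        for key in li:
            sd['model1.' + key[6:]] = sd.pop(key)  # strip the 'model.' prefix
        for key in lo:
            sd['model2.' + key[6:]] = sd.pop(key)
        return sd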
@@ -721,22 +723,25 @@ def load_SD_model():
         del sd

-        if not st.session_state['defaults'].general.no_half:
+        if not st.session_state.defaults.general.no_half:
             model = model.half()
             modelCS = modelCS.half()
             modelFS = modelFS.half()

-        return model,modelCS,modelFS,device, config
+        return config, device, model, modelCS, modelFS
     else:
-        config = OmegaConf.load(st.session_state['defaults'].general.default_model_config)
-        model = load_model_from_config(config, st.session_state['defaults'].general.default_model_path)
-        device = torch.device(f"cuda:{opt.gpu}") if torch.cuda.is_available() else torch.device("cpu")
-        model = (model if st.session_state['defaults'].general.no_half else model.half()).to(device)
-        return model, device,config
+        config = OmegaConf.load(st.session_state.defaults.general.default_model_config)
+        model = load_model_from_config(config, ckpt_path)
+        device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \
+            if torch.cuda.is_available() else torch.device("cpu")
+        model = (model if st.session_state.defaults.general.no_half
+                 else model.half()).to(device)
+
+        return config, device, model, None, None

+#
+# @codedealer: No usages
 def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'):
     #get global variables
     global_vars = globals()
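
Both branches now share the `(config, device, model, modelCS, modelFS)` return shape (with `None` stages in the non-optimized case), and both gate fp16 conversion on `no_half`. A tiny runnable sketch of that half-precision gate (the `Linear` stands in for the real model):

    import torch

    no_half = False  # mirrors st.session_state.defaults.general.no_half
    model = torch.nn.Linear(4, 4)  # stand-in for the Stable Diffusion model
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # Same expression shape as the commit: keep fp32 if no_half, else cast to fp16.
    model = (model if no_half else model.half()).to(device)
    print(next(model.parameters()).dtype)  # torch.float16 unless no_half is set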
@@ -750,8 +755,8 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
             if m == 'model':
                 del global_vars[m+'FS']
                 del global_vars[m+'CS']
-            if m =='model':
-                m='Stable Diffusion'
+            if m == 'model':
+                m = 'Stable Diffusion'
             print('Unloaded ' + m)
     if load:
         for m in models:
@@ -792,11 +797,11 @@ def generation_callback(img, i=0):
     # It can probably be done in a better way for someone who knows what they're doing. I don't.
     #print (img,isinstance(img, torch.Tensor))
     if isinstance(img, torch.Tensor):
-        x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).decode_first_stage(img)
+        x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img)
     else:
         # When using the k Diffusion samplers they return a dict instead of a tensor that look like this:
         # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}
-        x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).decode_first_stage(img["denoised"])
+        x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img["denoised"])

     x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -1025,10 +1030,10 @@ def draw_prompt_matrix(im, width, height, all_prompts):

 def check_prompt_length(prompt, comments):
     """this function tests if prompt is too long, and if so, adds a message to comments"""

-    tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelCS).cond_stage_model.tokenizer
-    max_length = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelCS).cond_stage_model.max_length
-    info = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
+    tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
+    max_length = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.max_length
+    info = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
                                                      return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
     ovf = info['overflowing_tokens'][0]
     overflowing_count = ovf.shape[0]
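
The ternary repeated in this and the surrounding hunks always performs the same dispatch: in optimized mode the conditioning stage (`modelCS`) and first stage (`modelFS`, the VAE) live in their own `st.session_state` entries; otherwise the monolithic model serves both roles. Hypothetical accessors (not in the commit) that name that dispatch:

    import streamlit as st

    def cond_stage_model():
        # Text conditioning: modelCS in optimized mode, else the full model.
        if st.session_state['defaults'].general.optimized:
            return st.session_state.modelCS
        return st.session_state["model"]

    def first_stage_model():
        # VAE encode/decode: modelFS in optimized mode, else the full model.
        if st.session_state['defaults'].general.optimized:
            return st.session_state.modelFS
        return st.session_state["model"]

    # e.g. tokenizer = cond_stage_model().cond_stage_model.tokenizer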
@@ -1322,9 +1327,9 @@ def process_images(
             print(prompt)

             if st.session_state['defaults'].general.optimized:
-                modelCS.to(st.session_state['defaults'].general.gpu)
+                st.session_state.modelCS.to(st.session_state['defaults'].general.gpu)

-            uc = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""])
+            uc = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(len(prompts) * [""])

             if isinstance(prompts, tuple):
                 prompts = list(prompts)
@@ -1338,16 +1343,16 @@ def process_images(
                 c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
                 for i in range(0, len(weighted_subprompts)):
                     # note if alpha negative, it functions same as torch.sub
-                    c = torch.add(c, (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
+                    c = torch.add(c, (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
             else: # just behave like usual
-                c = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelCS).get_learned_conditioning(prompts)
+                c = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(prompts)

             shape = [opt_C, height // opt_f, width // opt_f]

             if st.session_state['defaults'].general.optimized:
                 mem = torch.cuda.memory_allocated()/1e6
-                modelCS.to("cpu")
+                st.session_state.modelCS.to("cpu")
                 while(torch.cuda.memory_allocated()/1e6 >= mem):
                     time.sleep(1)
@@ -1376,9 +1381,9 @@ def process_images(
                 samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)

                 if st.session_state['defaults'].general.optimized:
-                    modelFS.to(st.session_state['defaults'].general.gpu)
+                    st.session_state.modelFS.to(st.session_state['defaults'].general.gpu)

-                x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).decode_first_stage(samples_ddim)
+                x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(samples_ddim)
                 x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                 for i, x_sample in enumerate(x_samples_ddim):
@@ -1512,7 +1517,7 @@ def process_images(
             if st.session_state['defaults'].general.optimized:
                 mem = torch.cuda.memory_allocated()/1e6
-                modelFS.to("cpu")
+                st.session_state.modelFS.to("cpu")
                 while(torch.cuda.memory_allocated()/1e6 >= mem):
                     time.sleep(1)

View File

@@ -162,16 +162,26 @@ def layout():
         cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
         seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
         batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")

-        #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1,
-        #help="How many images are at once in a batch.\
-        #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
-        #Default: 1")
+        bs_slider_max_value = 5
+        if st.session_state.defaults.general.optimized:
+            bs_slider_max_value = 100
+
+        batch_size = st.slider(
+            "Batch size",
+            min_value=1,
+            max_value=bs_slider_max_value,
+            value=st.session_state.defaults.txt2img.batch_size,
+            step=1,
+            help="How many images are at once in a batch.\
+            It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+            Default: 1")

         with st.expander("Preview Settings"):
             st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2img.update_preview,
                                                              help="If enabled the image preview will be updated during the generation instead of at the end. \
                                                              You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
                                                              By default this is enabled and the frequency is set to 1 step.")

             st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency,
                                                                          help="Frequency in steps at which the the preview image is updated. By default the frequency \
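
The two slider caps reflect VRAM behavior: in optimized mode the conditioning and VAE stages are offloaded while the UNet samples, so far larger batches fit, and the limit rises from 5 to 100. The same gate, condensed to one line (equivalent, not the commit's literal form):

    bs_slider_max_value = 100 if st.session_state.defaults.general.optimized else 5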
@@ -244,9 +254,9 @@ def layout():
                 load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)

                 try:
-                    output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1,
+                    output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, batch_size,
                                                                 cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
-                                                                save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp,
+                                                                save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=st.session_state.defaults.general.fp,
                                                                 variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)

                     message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")