Moved the models to use the server_state component instead of session_state so it can be shared between multiple sessions, tabs and users as long as the streamlit server is running.

Moved the models to use the server_state component instead of session_state so it can be shared between multiple sessions, tabs and users as long as the streamlit server is running.
This commit is contained in:
ZeroCool 2022-09-25 04:03:48 -07:00 committed by GitHub
commit f4c8b9500f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 744 additions and 465 deletions

View File

@ -273,8 +273,13 @@ img2img:
variant_seed: ""
write_info_files: True
concepts_library:
concepts_per_page: 12
gfpgan:
strength: 100
textual_inversion:
value: 0

View File

@ -42,6 +42,7 @@ dependencies:
- streamlit-on-Hover-tabs==1.0.1
- streamlit-option-menu==0.3.2
- streamlit_nested_layout
- streamlit-server-state==0.14.2
- test-tube>=0.7.5
- tensorboard
- torch-fidelity==0.3.0

View File

@ -6,6 +6,7 @@ from sd_utils import *
#streamlit components section
import streamlit_nested_layout
from streamlit_server_state import server_state, server_state_lock
#other imports
from omegaconf import OmegaConf
@ -17,9 +18,11 @@ def layout():
st.header("Settings")
with st.form("Settings"):
general_tab, txt2img_tab, img2img_tab, txt2vid_tab, textual_inversion_tab = st.tabs(['General', "Text-To-Image",
general_tab, txt2img_tab, img2img_tab, \
txt2vid_tab, textual_inversion_tab, concepts_library_tab = st.tabs(['General', "Text-To-Image",
"Image-To-Image", "Text-To-Video",
"Textual Inversion"])
"Textual Inversion",
"Concepts Library"])
with general_tab:
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
@ -47,8 +50,8 @@ def layout():
custom_models_available()
if st.session_state.CustomModel_available:
st.session_state.default_model = st.selectbox("Default Model:", st.session_state.custom_models,
index=st.session_state.custom_models.index(st.session_state['defaults'].general.default_model),
st.session_state.default_model = st.selectbox("Default Model:", server_state["custom_models"],
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. If you have placed custom models \
on your 'models/custom' folder they will be shown here as well. The model name that will be shown here \
is the same as the name the file for the model has on said folder, \
@ -197,6 +200,14 @@ def layout():
st.title("Textual Inversion")
st.info("Under Construction. :construction_worker:")
with concepts_library_tab:
st.title("Concepts Library")
#st.info("Under Construction. :construction_worker:")
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
with col1:
st.session_state["defaults"].concepts_library.concepts_per_page = int(st.text_input("Concepts Per Page", value=st.session_state['defaults'].concepts_library.concepts_per_page,
help="Number of concepts per page to show on the Concepts Library. Default: '12'"))
# add space for the buttons at the bottom
st.markdown("---")

View File

@ -28,8 +28,6 @@ from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from slugify import slugify
import json
import os
import sys
logger = get_logger(__name__)
@ -40,7 +38,6 @@ def parse_args():
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
@ -50,17 +47,16 @@ def parse_args():
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
"--train_data_dir", type=str, default=None, help="A folder containing the training data."
)
parser.add_argument(
"--placeholder_token",
type=str,
default=None,
required=True,
help="A token to use as a placeholder for the concept.",
)
parser.add_argument(
"--initializer_token", type=str, default=None, required=True,help="A token to use as initializer word."
"--initializer_token", type=str, default=None, help="A token to use as initializer word."
)
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
@ -68,7 +64,6 @@ def parse_args():
"--output_dir",
type=str,
default="text-inversion-model",
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@ -128,6 +123,31 @@ def parse_args():
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument(
"--use_auth_token",
action="store_true",
help=(
"Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
" private models)."
),
)
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
@ -144,14 +164,32 @@ def parse_args():
"--checkpoint_frequency",
type=int,
default=500,
help="How often to save a checkpoint",
help="How often to save a checkpoint and sample image",
)
parser.add_argument(
"--stable_sample_batches",
type=int,
default=0,
help="Number of fixed seed sample batches to generate per checkpoint",
)
parser.add_argument(
"--random_sample_batches",
type=int,
default=1,
help="Number of random seed sample batches to generate per checkpoint",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=1,
help="Number of samples to generate per batch",
)
parser.add_argument(
"--custom_templates",
type=str,
default=None,
help=(
"A comma-delimited list of custom templates to use"
"A comma-delimited list of custom template to use for samples, using {} as a placeholder for the concept."
),
)
parser.add_argument(
@ -166,19 +204,10 @@ def parse_args():
default=None,
help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
)
parser.add_argument(
"--config",
type=str,
default=None,
help="Path to a JSON config file specifying the arguments to use. If resume_from is given, it is automatically inferred."
)
args = parser.parse_args()
if args.config is not None:
with open(args.config, 'rt') as f:
args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)))
elif args.resume_from is not None:
with open(f"{args.resume_from}/resume.json", 'rt') as f:
if args.resume_from is not None:
with open(Path(args.resume_from) / "resume.json", 'rt') as f:
args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"]))
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@ -277,10 +306,10 @@ class TextualInversionDataset(Dataset):
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
}[interpolation]
self.templates = templates
@ -329,6 +358,16 @@ class TextualInversionDataset(Dataset):
return example
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def freeze_params(params):
for param in params:
param.requires_grad = False
@ -337,7 +376,7 @@ def freeze_params(params):
def save_resume_file(basepath, args, extra = {}):
info = {"args": vars(args)}
info["args"].update(extra)
with open(f"{basepath}/resume.json", "w") as f:
with open(Path(basepath) / "resume.json", "w") as f:
json.dump(info, f, indent=4)
@ -352,6 +391,9 @@ class Checkpointer:
placeholder_token_id,
templates,
output_dir,
random_sample_batches,
sample_batch_size,
stable_sample_batches,
seed
):
self.accelerator = accelerator
@ -362,14 +404,17 @@ class Checkpointer:
self.placeholder_token_id = placeholder_token_id
self.templates = templates
self.output_dir = output_dir
self.random_sample_batches = random_sample_batches
self.sample_batch_size = sample_batch_size
self.stable_sample_batches = stable_sample_batches
self.seed = seed
def checkpoint(self, step, text_encoder):
def checkpoint(self, step, text_encoder, save_samples=True):
print("Saving checkpoint for step %d..." % step)
with torch.autocast("cuda"):
checkpoints_path = f"{self.output_dir}/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)
checkpoints_path = self.output_dir / "checkpoints"
checkpoints_path.mkdir(exist_ok=True, parents=True)
unwrapped = self.accelerator.unwrap_model(text_encoder)
@ -378,27 +423,94 @@ class Checkpointer:
learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
filename = f"learned_embeds_%s_%d.bin" % (slugify(self.placeholder_token), step)
torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
torch.save(learned_embeds_dict, checkpoints_path / filename)
torch.save(learned_embeds_dict, checkpoints_path / "last.bin")
del unwrapped
return f"{checkpoints_path}/last.bin"
return checkpoints_path / "last.bin"
def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
samples_path = self.output_dir / "samples"
samples_path.mkdir(exist_ok=True, parents=True)
checker = NoCheck()
with torch.autocast("cuda"):
unwrapped = self.accelerator.unwrap_model(text_encoder)
# Save a sample image
pipeline = StableDiffusionPipeline(
text_encoder=unwrapped,
vae=self.vae,
unet=self.unet,
tokenizer=self.tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=NoCheck(),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to('cuda')
pipeline.enable_attention_slicing()
if self.stable_sample_batches > 0:
stable_latents = torch.randn(
(self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
device=pipeline.device,
generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
)
stable_prompts = [choice.format(self.placeholder_token) for choice in (self.templates * self.sample_batch_size)[:self.sample_batch_size]]
# Generate and save stable samples
for i in range(0, self.stable_sample_batches):
samples = pipeline(
prompt=stable_prompts,
height=max(512, height),
latents=stable_latents,
width=max(512, width),
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"stable_sample_%d_%d_step_%d.png" % (i+1, idx+1, step)
im.save(samples_path / filename)
prompts = [choice.format(self.placeholder_token) for choice in random.choices(self.templates, k=self.sample_batch_size)]
# Generate and save random samples
for i in range(0, self.random_sample_batches):
samples = pipeline(
prompt=prompts,
height=max(512, height),
width=max(512, width),
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"step_%d_sample_%d_%d.png" % (step, i+1, idx+1)
im.save(samples_path / filename)
del im
del pipeline
del unwrapped
def main():
args = parse_args()
global_step_offset = 0
if args.resume_from is not None:
basepath = f"{args.resume_from}"
basepath = Path(args.resume_from)
print("Resuming state from %s" % args.resume_from)
with open(f"{basepath}/resume.json", 'r') as f:
with open(basepath / "resume.json", 'r') as f:
state = json.load(f)
global_step_offset = state["args"]["global_step"]
print("We've trained %d steps so far" % global_step_offset)
else:
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
os.makedirs(basepath, exist_ok=True)
basepath = Path(args.logging_dir) / slugify(args.placeholder_token)
basepath.mkdir(exist_ok=True, parents=True)
accelerator = Accelerator(
@ -410,6 +522,23 @@ def main():
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer and add the placeholder token as a additional special token
if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@ -487,6 +616,9 @@ def main():
placeholder_token_id=placeholder_token_id,
templates=templates,
output_dir=basepath,
sample_batch_size=args.sample_batch_size,
random_sample_batches=args.random_sample_batches,
stable_sample_batches=args.stable_sample_batches,
seed=args.seed
)
@ -518,7 +650,7 @@ def main():
learnable_property=args.learnable_property,
center_crop=args.center_crop,
set="train",
templates=templates
templates=base_templates
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
@ -628,6 +760,8 @@ def main():
"global_step": global_step + global_step_offset,
"resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
})
checkpointer.save_samples(global_step + global_step_offset, text_encoder,
args.resolution, args.resolution, 7.5, 0.0, 25)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@ -640,16 +774,36 @@ def main():
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
#pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, f"{basepath}/learned_embeds.bin")
torch.save(learned_embeds_dict, basepath / f"learned_embeds.bin")
if global_step % args.checkpoint_frequency != 0:
checkpointer.save_samples(global_step + global_step_offset, text_encoder,
args.resolution, args.resolution, 7.5, 0.0, 25)
print("Saving resume state")
save_resume_file(basepath, args, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
"resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
})
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()
except KeyboardInterrupt:
@ -658,7 +812,7 @@ def main():
checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
save_resume_file(basepath, args, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
"resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
})
quit()

View File

@ -65,21 +65,21 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
#use_RealESRGAN = 11 in toggles
if sampler_name == 'PLMS':
sampler = PLMSSampler(st.session_state["model"])
sampler = PLMSSampler(server_state["model"])
elif sampler_name == 'DDIM':
sampler = DDIMSampler(st.session_state["model"])
sampler = DDIMSampler(server_state["model"])
elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
sampler = KDiffusionSampler(server_state["model"],'dpm_2')
elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(st.session_state["model"],'euler')
sampler = KDiffusionSampler(server_state["model"],'euler')
elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(st.session_state["model"],'heun')
sampler = KDiffusionSampler(server_state["model"],'heun')
elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(st.session_state["model"],'lms')
sampler = KDiffusionSampler(server_state["model"],'lms')
else:
raise Exception("Unknown sampler: " + sampler_name)
@ -160,18 +160,18 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
mask = (1 - mask)
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3)
mask = torch.from_numpy(mask).to(st.session_state["device"])
mask = torch.from_numpy(mask).to(server_state["device"])
if st.session_state['defaults'].general.optimized:
st.session_state.modelFS.to(st.session_state["device"] )
server_state["modelFS"].to(server_state["device"] )
init_image = 2. * image - 1.
init_image = init_image.to(st.session_state["device"])
init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
init_image = init_image.to(server_state["device"])
init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6
st.session_state.modelFS.to("cpu")
server_state["modelFS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
@ -208,7 +208,7 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
x0, z_mask = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] ))
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(server_state["device"] ))
# Obliterate masked image
if z_mask is not None and obliterate:
@ -377,8 +377,8 @@ def layout():
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
custom_models_available()
if st.session_state["CustomModel_available"]:
st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"],
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. This option is only available if you have custom models \
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
@ -442,13 +442,13 @@ def layout():
help="Save a file next to the image with informartion about the generation.")
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
if st.session_state["GFPGAN_available"]:
if server_state["GFPGAN_available"]:
use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else:
use_GFPGAN = False
if st.session_state["RealESRGAN_available"]:
if server_state["RealESRGAN_available"]:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
help="Uses the RealESRGAN model to upscale the images after the generation.\
This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")

View File

@ -22,7 +22,7 @@ def sdConceptsBrowser(concepts, key=None):
return component_value
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getConceptsFromPath(page, conceptPerPage, searchText=""):
#print("getConceptsFromPath", "page:", page, "conceptPerPage:", conceptPerPage, "searchText:", searchText)
# get the path where the concepts are stored
@ -97,7 +97,6 @@ def getConceptsFromPath(page, conceptPerPage, searchText=""):
#print("Results:", [c["name"] for c in concepts])
return concepts
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
def imageToBase64(image):
import io
@ -108,7 +107,7 @@ def imageToBase64(image):
return img_str
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getTotalNumberOfConcepts(searchText=""):
# get the path where the concepts are stored
path = os.path.join(
@ -138,7 +137,7 @@ def layout():
# Concept Library
with tab_library:
downloaded_concepts_count = getTotalNumberOfConcepts()
concepts_per_page = 12
concepts_per_page = st.session_state["defaults"].concepts_library.concepts_per_page
if not "results" in st.session_state:
st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, "")
@ -178,7 +177,7 @@ def layout():
# Previous page
with _previous_page:
if st.button("<", key="cl_previous_page"):
if st.button("Previous", key="cl_previous_page"):
st.session_state["cl_current_page"] -= 1
if st.session_state["cl_current_page"] <= 0:
st.session_state["cl_current_page"] = last_page
@ -190,7 +189,7 @@ def layout():
# Next page
with _next_page:
if st.button(">", key="cl_next_page"):
if st.button("Next", key="cl_next_page"):
st.session_state["cl_current_page"] += 1
if st.session_state["cl_current_page"] > last_page:
st.session_state["cl_current_page"] = 1

View File

@ -4,6 +4,10 @@ from webui_streamlit import st
# streamlit imports
from streamlit import StopException
#streamlit components section
from streamlit_server_state import server_state, server_state_lock
#other imports
import warnings
@ -55,6 +59,7 @@ except:
# remove some annoying deprecation warnings that show every now and then.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
@ -153,15 +158,17 @@ def human_readable_size(size, decimal_places=3):
size /= 1024.0
return f"{size:.{decimal_places}f}{unit}"
@retry(tries=5)
def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus",
CustomModel_available=False, custom_model="Stable Diffusion v1.4"):
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
print ("Loading models.")
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("Loading models...")
# Generate random run ID
# Used to link runs linked w/ continue_prev_run which is not yet implemented
# Use URL and filesystem safe version just in case.
@ -171,83 +178,95 @@ def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=Fals
# check what models we want to use and if the they are already loaded.
with server_state_lock["GFPGAN"]:
if use_GFPGAN:
if "GFPGAN" in st.session_state:
if "GFPGAN" in server_state:
print("GFPGAN already loaded")
else:
# Load GFPGAN
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
try:
st.session_state["GFPGAN"] = load_GFPGAN()
server_state["GFPGAN"] = load_GFPGAN()
print("Loaded GFPGAN")
except Exception:
import traceback
print("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
else:
if "GFPGAN" in st.session_state:
del st.session_state["GFPGAN"]
if "GFPGAN" in server_state:
del server_state["GFPGAN"]
with server_state_lock["RealESRGAN"]:
if use_RealESRGAN:
if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model:
if "RealESRGAN" in server_state and server_state["RealESRGAN"].model.name == RealESRGAN_model:
print("RealESRGAN already loaded")
else:
#Load RealESRGAN
try:
# We first remove the variable in case it has something there,
# some errors can load the model incorrectly and leave things in memory.
del st.session_state["RealESRGAN"]
del server_state["RealESRGAN"]
except KeyError:
pass
if os.path.exists(st.session_state["defaults"].general.RealESRGAN_dir):
# st.session_state is used for keeping the models in memory across multiple pages or runs.
st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model)
print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name)
server_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model)
print("Loaded RealESRGAN with model "+ server_state["RealESRGAN"].model.name)
else:
if "RealESRGAN" in st.session_state:
del st.session_state["RealESRGAN"]
if "RealESRGAN" in server_state:
del server_state["RealESRGAN"]
if "model" in st.session_state:
if "model" in st.session_state and st.session_state["loaded_model"] == custom_model:
with server_state_lock["model"], server_state_lock["modelCS"], server_state_lock["modelFS"], server_state_lock["loaded_model"]:
if "model" in server_state:
if "model" in server_state and server_state["loaded_model"] == custom_model:
# TODO: check if the optimized mode was changed?
if "pipe" in st.session_state:
del st.session_state.pipe
print("Model already loaded")
return
else:
try:
del st.session_state.model
del st.session_state.modelCS
del st.session_state.modelFS
del st.session_state.loaded_model
del server_state["model"]
del server_state["modelCS"]
del server_state["modelFS"]
del server_state["loaded_model"]
if "pipe" in st.session_state:
del st.session_state.pipe
except KeyError:
pass
# if the model from txt2vid is in memory we need to remove it to improve performance.
with server_state_lock["pipe"]:
if "pipe" in server_state:
del server_state["pipe"]
# At this point the model is either
# is not loaded yet or have been evicted:
# load new model into memory
st.session_state.custom_model = custom_model
server_state["custom_model"] = custom_model
config, device, model, modelCS, modelFS = load_sd_model(custom_model)
st.session_state.device = device
st.session_state.model = model
st.session_state.modelCS = modelCS
st.session_state.modelFS = modelFS
st.session_state.loaded_model = custom_model
server_state["device"] = device
server_state["model"] = model
server_state["modelCS"] = modelCS
server_state["modelFS"] = modelFS
server_state["loaded_model"] = custom_model
#trying to disable multiprocessing as it makes it so streamlit cant stop when the
# model is loaded in memory and you need to kill the process sometimes.
try:
server_state["model"].args.use_multiprocessing_for_evaluation = False
except:
pass
if st.session_state.defaults.general.enable_attention_slicing:
st.session_state.model.enable_attention_slicing()
server_state["model"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
st.session_state.model.enable_minimal_memory_usage()
server_state["model"].enable_minimal_memory_usage()
print("Model loaded.")
@ -584,6 +603,7 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim)
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}')
@ -656,6 +676,7 @@ def torch_gc():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@retry(tries=5)
def load_GFPGAN():
model_name = 'GFPGANv1.3'
model_path = os.path.join(st.session_state['defaults'].general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
@ -673,6 +694,7 @@ def load_GFPGAN():
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
return instance
@retry(tries=5)
def load_RealESRGAN(model_name: str):
from basicsr.archs.rrdbnet_arch import RRDBNet
RealESRGAN_models = {
@ -700,6 +722,7 @@ def load_RealESRGAN(model_name: str):
return instance
#
@retry(tries=5)
def load_LDSR(checking=False):
model_name = 'model'
yaml_name = 'project'
@ -719,6 +742,8 @@ def load_LDSR(checking=False):
#
LDSR = None
@retry(tries=5)
def try_loading_LDSR(model_name: str,checking=False):
global LDSR
if os.path.exists(st.session_state['defaults'].general.LDSR_dir):
@ -739,8 +764,10 @@ def try_loading_LDSR(model_name: str,checking=False):
# Loads Stable Diffusion model by name
#@retry(tries=5)
def load_sd_model(model_name: str) -> [any, any, any, any, any]:
ckpt_path = st.session_state.defaults.general.default_model_path
if model_name != st.session_state.defaults.general.default_model:
ckpt_path = os.path.join("models", "custom", f"{model_name}.ckpt")
@ -864,12 +891,12 @@ def generation_callback(img, i=0):
# It can probably be done in a better way for someone who knows what they're doing. I don't.
#print (img,isinstance(img, torch.Tensor))
if isinstance(img, torch.Tensor):
x_samples_ddim = (st.session_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else st.session_state.modelFS.to('cuda')
x_samples_ddim = (server_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else server_state["modelFS"].to('cuda')
).decode_first_stage(img).to('cuda')
else:
# When using the k Diffusion samplers they return a dict instead of a tensor that look like this:
# {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}
x_samples_ddim = (st.session_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else st.session_state.modelFS.to('cuda')
x_samples_ddim = (server_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else server_state["modelFS"].to('cuda')
).decode_first_stage(img["denoised"]).to('cuda')
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@ -954,6 +981,7 @@ def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
return v2
#
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list):
"""Find the optimal update_preview_frequency value maximizing
performance while minimizing the time between updates."""
@ -989,8 +1017,8 @@ def get_font(fontsize):
raise Exception(f"No usable font found (tried {', '.join(fonts)})")
def load_embeddings(fp):
if fp is not None and hasattr(st.session_state["model"], "embedding_manager"):
st.session_state["model"].embedding_manager.load(fp['name'])
if fp is not None and hasattr(server_state["model"], "embedding_manager"):
server_state["model"].embedding_manager.load(fp['name'])
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
@ -1148,10 +1176,10 @@ def enable_minimal_memory_usage(model):
def check_prompt_length(prompt, comments):
"""this function tests if prompt is too long, and if so, adds a message to comments"""
tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
max_length = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.max_length
tokenizer = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.tokenizer
max_length = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.max_length
info = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
info = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
ovf = info['overflowing_tokens'][0]
overflowing_count = ovf.shape[0]
@ -1169,17 +1197,17 @@ def custom_models_available():
#
# Allow for custom models to be used instead of the default one,
# an example would be Waifu-Diffusion or any other fine tune of stable diffusion
st.session_state["custom_models"]:sorted = []
server_state["custom_models"]:sorted = []
for root, dirs, files in os.walk(os.path.join("models", "custom")):
for file in files:
if os.path.splitext(file)[1] == '.ckpt':
st.session_state["custom_models"].append(os.path.splitext(file)[0])
server_state["custom_models"].append(os.path.splitext(file)[0])
if len(st.session_state["custom_models"]) > 0:
if len(server_state["custom_models"]) > 0:
st.session_state["CustomModel_available"] = True
st.session_state["custom_models"].append("Stable Diffusion v1.4")
server_state["custom_models"].append("Stable Diffusion v1.4")
else:
st.session_state["CustomModel_available"] = False
@ -1217,7 +1245,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt
target="txt2img" if init_img is None else "img2img",
prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=st.session_state["loaded_model"])
seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=server_state["loaded_model"])
# Not yet any use for these, but they bloat up the files:
# info_dict["init_img"] = init_img
# info_dict["init_mask"] = init_mask
@ -1386,8 +1414,8 @@ def process_images(
if prompt_tokens:
# compviz
tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer
tokenizer = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.tokenizer
text_encoder = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.transformer
# diffusers
#tokenizer = pipe.tokenizer
@ -1471,7 +1499,7 @@ def process_images(
output_images = []
grid_captions = []
stats = []
with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not st.session_state['defaults'].general.optimized else nullcontext()):
with torch.no_grad(), precision_scope("cuda"), (server_state["model"].ema_scope() if not st.session_state['defaults'].general.optimized else nullcontext()):
init_data = func_init()
tic = time.time()
@ -1497,9 +1525,9 @@ def process_images(
print(prompt)
if st.session_state['defaults'].general.optimized:
st.session_state.modelCS.to(st.session_state['defaults'].general.gpu)
server_state["modelCS"].to(st.session_state['defaults'].general.gpu)
uc = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(len(prompts) * [negprompt])
uc = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(len(prompts) * [negprompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -1513,23 +1541,23 @@ def process_images(
c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
for i in range(0, len(weighted_subprompts)):
# note if alpha negative, it functions same as torch.sub
c = torch.add(c, (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
c = torch.add(c, (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
else: # just behave like usual
c = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(prompts)
c = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(prompts)
shape = [opt_C, height // opt_f, width // opt_f]
if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6
st.session_state.modelCS.to("cpu")
server_state["modelCS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
if noise_mode == 1 or noise_mode == 3:
# TODO params for find_noise_to_image
x = torch.cat(batch_size * [find_noise_for_image(
st.session_state["model"], st.session_state["device"],
server_state["model"], server_state["device"],
init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True,
generation_callback=generation_callback,
)], dim=0)
@ -1551,9 +1579,9 @@ def process_images(
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
if st.session_state['defaults'].general.optimized:
st.session_state.modelFS.to(st.session_state['defaults'].general.gpu)
server_state["modelFS"].to(st.session_state['defaults'].general.gpu)
x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(samples_ddim)
x_samples_ddim = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
run_images = []
@ -1567,7 +1595,7 @@ def process_images(
full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
sanitized_prompt = sanitized_prompt[:220-len(full_path)]
sanitized_prompt = sanitized_prompt[:200-len(full_path)]
sample_path_i = os.path.join(sample_path, sanitized_prompt)
#print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
@ -1580,7 +1608,7 @@ def process_images(
full_path = os.path.join(os.getcwd(), sample_path)
sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i)
filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before
filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)] #same as before
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
@ -1590,11 +1618,11 @@ def process_images(
st.session_state["preview_image"].image(image)
if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN:
if use_GFPGAN and server_state["GFPGAN"] is not None and not use_RealESRGAN:
st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
#skip_save = True # #287 >_>
torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
gfpgan_filename = original_filename + '-gfpgan'
@ -1602,7 +1630,7 @@ def process_images(
save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback,
uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta,
n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"])
n_iter, batch_size, i, denoising_strength, resize_mode, False, server_state["loaded_model"])
output_images.append(gfpgan_image) #287
run_images.append(gfpgan_image)
@ -1610,16 +1638,16 @@ def process_images(
if simple_templating:
grid_captions.append( captions[i] + "\ngfpgan" )
elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN:
elif use_RealESRGAN and server_state["RealESRGAN"] is not None and not use_GFPGAN:
st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
#skip_save = True # #287 >_>
torch_gc()
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
output, img_mode = server_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1]
esrgan_image = Image.fromarray(esrgan_sample)
@ -1630,7 +1658,7 @@ def process_images(
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"])
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, server_state["loaded_model"])
output_images.append(esrgan_image) #287
run_images.append(esrgan_image)
@ -1638,25 +1666,25 @@ def process_images(
if simple_templating:
grid_captions.append( captions[i] + "\nesrgan" )
elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None:
elif use_RealESRGAN and server_state["RealESRGAN"] is not None and use_GFPGAN and server_state["GFPGAN"] is not None:
st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
#skip_save = True # #287 >_>
torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
output, img_mode = server_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1]
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"])
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, server_state["loaded_model"])
output_images.append(gfpgan_esrgan_image) #287
run_images.append(gfpgan_esrgan_image)
@ -1674,16 +1702,16 @@ def process_images(
init_img = init_img.convert('RGB')
image = image.convert('RGB')
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
if use_RealESRGAN and server_state["RealESRGAN"] is not None:
if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
output, img_mode = server_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
output, img_mode = server_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')
@ -1692,7 +1720,7 @@ def process_images(
if save_individual_images:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, st.session_state["loaded_model"])
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, server_state["loaded_model"])
#if add_original_image or not simple_templating:
#output_images.append(image)
@ -1701,7 +1729,7 @@ def process_images(
if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6
st.session_state.modelFS.to("cpu")
server_state["modelFS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
@ -1735,7 +1763,7 @@ def process_images(
output_images.insert(0, grid)
grid_count = get_next_sequence_number(outpath, 'grid-')
grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}"
grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:200-len(full_path)])}.{grid_ext}"
grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True)
toc = time.time()
@ -1745,7 +1773,7 @@ def process_images(
info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and server_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and server_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''

View File

@ -1,29 +1,49 @@
import gc
import inspect
import warnings
from tqdm.auto import tqdm
from typing import List, Optional, Union
import torch
from diffusers import ModelMixin
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
PNDMScheduler)
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers import StableDiffusionPipelineOutput
#from diffusers.safety_checker import StableDiffusionSafetyChecker
class StableDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
):
super().__init__()
scheduler = scheduler.set_format("pt")
@ -37,10 +57,45 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_minimal_memory_usage(self):
"""Moves only unet to fp16 and to CUDA, while keepping lighter models on CPUs"""
self.unet.to(torch.float16).to(torch.device("cuda"))
self.enable_attention_slicing(1)
torch.cuda.empty_cache()
gc.collect()
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
prompt: Union[str, List[str]],
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
@ -48,38 +103,75 @@ class StableDiffusionPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
text_embeddings: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
# device = kwargs.pop("torch_device")
warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and"
" will be removed in v0.3.0. Consider using `pipe.to(torch_device)`"
" instead."
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" Consider using `pipe.to(torch_device)` instead."
)
# Set device as before (to be removed in 0.3.0)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device)
# if device is None:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# self.to(device)
if text_embeddings is None:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(
"`height` and `width` have to be divisible by 8 but are"
f" {height} and {width}."
)
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_input = self.tokenizer(
@ -89,9 +181,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
else:
batch_size = text_embeddings.shape[0]
text_embeddings = self.text_encoder(text_input.input_ids.to(self.text_encoder.device))[0].to(self.unet.device)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@ -99,17 +189,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
# max_length = text_input.input_ids.shape[-1]
max_length = 77 # self.tokenizer.model_max_length
max_length = text_input.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.text_encoder.device))[0].to(
self.unet.device
)
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(self.device)
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@ -117,25 +203,25 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
device=latents_device,
)
else:
if latents.shape != latents_shape:
raise ValueError(
f"Unexpected latents shape, got {latents.shape}, expected"
f" {latents_shape}"
)
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
# set timesteps
accepts_offset = "offset" in set(
inspect.signature(self.scheduler.set_timesteps).parameters.keys()
)
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1
@ -150,18 +236,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
@ -169,65 +251,43 @@ class StableDiffusionPipeline(DiffusionPipeline):
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings
)["sample"]
latent_model_input.to(self.unet.device), t.to(self.unet.device), encoder_hidden_states=text_embeddings
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(
noise_pred, i, latents, **extra_step_kwargs
)["prev_sample"]
noise_pred, i, latents.to(self.unet.device), **extra_step_kwargs
).prev_sample
else:
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
)["prev_sample"]
noise_pred, t.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs
).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = self.vae.decode(latents.to(self.vae.device)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
image = image.to(self.vae.device).to(self.vae.device).cpu().permute(0, 2, 3, 1).numpy()
safety_cheker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="pt"
).to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_cheker_input.pixel_values
# run safety checker
safety_cheker_input = (
self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt")
.to(self.vae.device)
.to(self.vae.dtype)
)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
def embed_text(self, text):
"""Helper to embed some text"""
with torch.autocast("cuda"):
text_input = self.tokenizer(
text,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
with torch.no_grad():
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
return embed
class NoCheck(ModelMixin):
"""Can be used in place of safety checker. Use responsibly and at your own risk."""
def __init__(self):
super().__init__()
self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))
def forward(self, images=None, **kwargs):
return images, [False]
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -36,14 +36,14 @@ class plugin_info():
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
GFPGAN_available = True
server_state["GFPGAN_available"] = True
else:
GFPGAN_available = False
server_state["GFPGAN_available"] = False
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
RealESRGAN_available = True
server_state["RealESRGAN_available"] = True
else:
RealESRGAN_available = False
server_state["RealESRGAN_available"] = False
#
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str,
@ -69,21 +69,21 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_na
#use_RealESRGAN = 8 in toggles
if sampler_name == 'PLMS':
sampler = PLMSSampler(st.session_state["model"])
sampler = PLMSSampler(server_state["model"])
elif sampler_name == 'DDIM':
sampler = DDIMSampler(st.session_state["model"])
sampler = DDIMSampler(server_state["model"])
elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
sampler = KDiffusionSampler(server_state["model"],'dpm_2')
elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(st.session_state["model"],'euler')
sampler = KDiffusionSampler(server_state["model"],'euler')
elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(st.session_state["model"],'heun')
sampler = KDiffusionSampler(server_state["model"],'heun')
elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(st.session_state["model"],'lms')
sampler = KDiffusionSampler(server_state["model"],'lms')
else:
raise Exception("Unknown sampler: " + sampler_name)
@ -209,8 +209,8 @@ def layout():
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
custom_models_available()
if st.session_state.CustomModel_available:
st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models,
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
server_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. This option is only available if you have custom models \
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
@ -243,13 +243,13 @@ def layout():
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with informartion about the generation.")
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
if st.session_state["GFPGAN_available"]:
if server_state["GFPGAN_available"]:
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else:
st.session_state["use_GFPGAN"] = False
if st.session_state["RealESRGAN_available"]:
if server_state["RealESRGAN_available"]:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN,
help="Uses the RealESRGAN model to upscale the images after the generation.\
This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")

View File

@ -6,6 +6,9 @@ from sd_utils import *
from streamlit import StopException
from streamlit.elements import image as STImage
#streamlit components section
from streamlit_server_state import server_state, server_state_lock
#other imports
import os
@ -19,11 +22,12 @@ from io import BytesIO
import imageio
from slugify import slugify
# Temp imports
from diffusers import StableDiffusionPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
PNDMScheduler
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
@ -201,6 +205,66 @@ def diffuse(
return image2
#
@st.experimental_singleton(show_spinner=False, suppress_st_warning=True)
def load_diffusers_model(weights_path,torch_device):
with server_state_lock["model"]:
if "model" in server_state:
del server_state["model"]
with server_state_lock["pipe"]:
try:
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
if st.session_state["weights_path"] != weights_path:
del st.session_state["weights_path"]
st.session_state["weights_path"] = weights_path
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
server_state["pipe"].unet.to(torch_device)
server_state["pipe"].vae.to(torch_device)
server_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
server_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
server_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
else:
print("Tx2Vid Model already Loaded")
except:
#del st.session_state["weights_path"]
#del server_state["pipe"]
st.session_state["weights_path"] = weights_path
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
server_state["pipe"].unet.to(torch_device)
server_state["pipe"].vae.to(torch_device)
server_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
server_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
server_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
#
def txt2vid(
# --------------------------------------
# args you probably want to change
@ -337,59 +401,12 @@ def txt2vid(
#print (st.session_state["weights_path"] != weights_path)
try:
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
if st.session_state["weights_path"] != weights_path:
del st.session_state["weights_path"]
load_diffusers_model(weights_path, torch_device)
st.session_state["weights_path"] = weights_path
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
server_state["pipe"].scheduler = SCHEDULERS[scheduler]
st.session_state["pipe"].unet.to(torch_device)
st.session_state["pipe"].vae.to(torch_device)
st.session_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
st.session_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
st.session_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
else:
print("Tx2Vid Model already Loaded")
except:
#del st.session_state["weights_path"]
#del st.session_state["pipe"]
st.session_state["weights_path"] = weights_path
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
st.session_state["pipe"].unet.to(torch_device)
st.session_state["pipe"].vae.to(torch_device)
st.session_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
st.session_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
st.session_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
server_state["pipe"].use_multiprocessing_for_evaluation = False
server_state["pipe"].use_multiprocessed_decoding = False
if do_loop:
prompts = str([prompts, prompts])
@ -399,8 +416,8 @@ def txt2vid(
#seeds.append(first_seed)
# get the conditional text embeddings based on the prompt
text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
text_input = server_state["pipe"].tokenizer(prompts, padding="max_length", max_length=server_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
cond_embeddings = server_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
#
if st.session_state.defaults.general.use_sd_concepts_library:
@ -434,7 +451,7 @@ def txt2vid(
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
# sample a source
init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
init1 = torch.randn((1, server_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
# iterate the loop
@ -451,7 +468,7 @@ def txt2vid(
st.session_state["current_frame"] = frame_index
# sample the destination
init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
init2 = torch.randn((1, server_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
for i, t in enumerate(np.linspace(0, 1, num_steps)):
start = timeit.default_timer()
@ -465,9 +482,9 @@ def txt2vid(
init = slerp(gpu, float(t), init1, init2)
with autocast("cuda"):
image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
image = diffuse(server_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
if st.session_state["save_individual_images"] and not server_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
#im = Image.fromarray(image)
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
image.save(outpath, quality=quality)
@ -481,13 +498,13 @@ def txt2vid(
#
#try:
#if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None:
#if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None:
#print("Running GFPGAN on image ...")
st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
#skip_save = True # #287 >_>
torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
@ -698,9 +715,9 @@ def layout():
st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
if GFPGAN_available:
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
server_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else:
st.session_state["use_GFPGAN"] = False
server_state["use_GFPGAN"] = False
if RealESRGAN_available:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
@ -726,16 +743,16 @@ def layout():
if generate_button:
#print("Loading models")
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
#load_models(False, server_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
if st.session_state["use_GFPGAN"]:
if server_state["use_GFPGAN"]:
if "GFPGAN" in st.session_state:
print("GFPGAN already loaded")
else:
# Load GFPGAN
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
try:
st.session_state["GFPGAN"] = load_GFPGAN()
server_state["GFPGAN"] = load_GFPGAN()
print("Loaded GFPGAN")
except Exception:
import traceback
@ -743,9 +760,9 @@ def layout():
print(traceback.format_exc(), file=sys.stderr)
else:
if "GFPGAN" in st.session_state:
del st.session_state["GFPGAN"]
del server_state["GFPGAN"]
#try:
try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
@ -801,7 +818,7 @@ def layout():
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
#except (StopException, KeyError):
#print(f"Received Streamlit StopException")
except (StopException, KeyError):
print(f"Received Streamlit StopException")

View File

@ -7,6 +7,7 @@ import streamlit_nested_layout
#streamlit components section
from st_on_hover_tabs import on_hover_tabs
from streamlit_server_state import server_state, server_state_lock
#other imports
@ -41,6 +42,7 @@ except:
# remove some annoying deprecation warnings that show every now and then.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# this should force GFPGAN and RealESRGAN onto the selected gpu as well
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
@ -70,15 +72,17 @@ def layout():
load_css(True, 'frontend/css/streamlit.main.css')
# check if the models exist on their respective folders
with server_state_lock["GFPGAN_available"]:
if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
st.session_state["GFPGAN_available"] = True
server_state["GFPGAN_available"] = True
else:
st.session_state["GFPGAN_available"] = False
server_state["GFPGAN_available"] = False
with server_state_lock["RealESRGAN_available"]:
if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
st.session_state["RealESRGAN_available"] = True
server_state["RealESRGAN_available"] = True
else:
st.session_state["RealESRGAN_available"] = False
server_state["RealESRGAN_available"] = False
## Allow for custom models to be used instead of the default one,
## an example would be Waifu-Diffusion or any other fine tune of stable diffusion