Moved the models to the server_state component instead of session_state so they can be shared between multiple sessions, tabs, and users for as long as the Streamlit server is running.

ZeroCool 2022-09-25 04:03:48 -07:00 committed by GitHub
commit f4c8b9500f
11 changed files with 744 additions and 465 deletions
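
For context, a minimal sketch of the session_state → server_state pattern this commit applies in load_models(): streamlit_server_state's server_state mapping is shared by every session and tab for as long as the server process lives, and server_state_lock guards each key against concurrent loads. The "model" and "loaded_model" keys mirror the diff below; load_sd_model here is only a placeholder standing in for the project's real checkpoint loader, not part of the library.

```python
from streamlit_server_state import server_state, server_state_lock


def load_sd_model(model_name: str):
    # Placeholder for the project's real (expensive) checkpoint loader.
    return f"<weights for {model_name}>"


def get_shared_model(model_name: str):
    # server_state persists on the server process, so the loaded model is
    # reused across sessions, tabs and users instead of per-tab session_state.
    with server_state_lock["model"]:
        if ("loaded_model" not in server_state
                or server_state["loaded_model"] != model_name):
            server_state["model"] = load_sd_model(model_name)
            server_state["loaded_model"] = model_name
    return server_state["model"]
```

Because the cache lives on the server rather than in the browser session, a model loaded by one user stays available to everyone until the Streamlit server restarts or the key is evicted (as the diff below does with del server_state["model"] when switching checkpoints).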

View File

@ -273,8 +273,13 @@ img2img:
variant_seed: "" variant_seed: ""
write_info_files: True write_info_files: True
concepts_library:
concepts_per_page: 12
gfpgan: gfpgan:
strength: 100 strength: 100
textual_inversion: textual_inversion:
value: 0 value: 0

View File

@ -42,6 +42,7 @@ dependencies:
- streamlit-on-Hover-tabs==1.0.1 - streamlit-on-Hover-tabs==1.0.1
- streamlit-option-menu==0.3.2 - streamlit-option-menu==0.3.2
- streamlit_nested_layout - streamlit_nested_layout
- streamlit-server-state==0.14.2
- test-tube>=0.7.5 - test-tube>=0.7.5
- tensorboard - tensorboard
- torch-fidelity==0.3.0 - torch-fidelity==0.3.0

View File

@ -6,6 +6,7 @@ from sd_utils import *
#streamlit components section #streamlit components section
import streamlit_nested_layout import streamlit_nested_layout
from streamlit_server_state import server_state, server_state_lock
#other imports #other imports
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -17,9 +18,11 @@ def layout():
st.header("Settings") st.header("Settings")
with st.form("Settings"): with st.form("Settings"):
general_tab, txt2img_tab, img2img_tab, txt2vid_tab, textual_inversion_tab = st.tabs(['General', "Text-To-Image", general_tab, txt2img_tab, img2img_tab, \
txt2vid_tab, textual_inversion_tab, concepts_library_tab = st.tabs(['General', "Text-To-Image",
"Image-To-Image", "Text-To-Video", "Image-To-Image", "Text-To-Video",
"Textual Inversion"]) "Textual Inversion",
"Concepts Library"])
with general_tab: with general_tab:
col1, col2, col3, col4, col5 = st.columns(5, gap='large') col1, col2, col3, col4, col5 = st.columns(5, gap='large')
@ -47,8 +50,8 @@ def layout():
custom_models_available() custom_models_available()
if st.session_state.CustomModel_available: if st.session_state.CustomModel_available:
st.session_state.default_model = st.selectbox("Default Model:", st.session_state.custom_models, st.session_state.default_model = st.selectbox("Default Model:", server_state["custom_models"],
index=st.session_state.custom_models.index(st.session_state['defaults'].general.default_model), index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. If you have placed custom models \ help="Select the model you want to use. If you have placed custom models \
on your 'models/custom' folder they will be shown here as well. The model name that will be shown here \ on your 'models/custom' folder they will be shown here as well. The model name that will be shown here \
is the same as the name the file for the model has on said folder, \ is the same as the name the file for the model has on said folder, \
@ -197,6 +200,14 @@ def layout():
st.title("Textual Inversion") st.title("Textual Inversion")
st.info("Under Construction. :construction_worker:") st.info("Under Construction. :construction_worker:")
with concepts_library_tab:
st.title("Concepts Library")
#st.info("Under Construction. :construction_worker:")
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
with col1:
st.session_state["defaults"].concepts_library.concepts_per_page = int(st.text_input("Concepts Per Page", value=st.session_state['defaults'].concepts_library.concepts_per_page,
help="Number of concepts per page to show on the Concepts Library. Default: '12'"))
# add space for the buttons at the bottom # add space for the buttons at the bottom
st.markdown("---") st.markdown("---")

View File

@ -28,8 +28,6 @@ from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from slugify import slugify from slugify import slugify
import json import json
import os
import sys
logger = get_logger(__name__) logger = get_logger(__name__)
@ -40,7 +38,6 @@ def parse_args():
"--pretrained_model_name_or_path", "--pretrained_model_name_or_path",
type=str, type=str,
default=None, default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.", help="Path to pretrained model or model identifier from huggingface.co/models.",
) )
parser.add_argument( parser.add_argument(
@ -50,17 +47,16 @@ def parse_args():
help="Pretrained tokenizer name or path if not the same as model_name", help="Pretrained tokenizer name or path if not the same as model_name",
) )
parser.add_argument( parser.add_argument(
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." "--train_data_dir", type=str, default=None, help="A folder containing the training data."
) )
parser.add_argument( parser.add_argument(
"--placeholder_token", "--placeholder_token",
type=str, type=str,
default=None, default=None,
required=True,
help="A token to use as a placeholder for the concept.", help="A token to use as a placeholder for the concept.",
) )
parser.add_argument( parser.add_argument(
"--initializer_token", type=str, default=None, required=True,help="A token to use as initializer word." "--initializer_token", type=str, default=None, help="A token to use as initializer word."
) )
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
@ -68,7 +64,6 @@ def parse_args():
"--output_dir", "--output_dir",
type=str, type=str,
default="text-inversion-model", default="text-inversion-model",
required=True,
help="The output directory where the model predictions and checkpoints will be written.", help="The output directory where the model predictions and checkpoints will be written.",
) )
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@ -128,6 +123,31 @@ def parse_args():
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument(
"--use_auth_token",
action="store_true",
help=(
"Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
" private models)."
),
)
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument( parser.add_argument(
"--mixed_precision", "--mixed_precision",
type=str, type=str,
@ -144,14 +164,32 @@ def parse_args():
"--checkpoint_frequency", "--checkpoint_frequency",
type=int, type=int,
default=500, default=500,
help="How often to save a checkpoint", help="How often to save a checkpoint and sample image",
)
parser.add_argument(
"--stable_sample_batches",
type=int,
default=0,
help="Number of fixed seed sample batches to generate per checkpoint",
)
parser.add_argument(
"--random_sample_batches",
type=int,
default=1,
help="Number of random seed sample batches to generate per checkpoint",
)
parser.add_argument(
"--sample_batch_size",
type=int,
default=1,
help="Number of samples to generate per batch",
) )
parser.add_argument( parser.add_argument(
"--custom_templates", "--custom_templates",
type=str, type=str,
default=None, default=None,
help=( help=(
"A comma-delimited list of custom templates to use" "A comma-delimited list of custom template to use for samples, using {} as a placeholder for the concept."
), ),
) )
parser.add_argument( parser.add_argument(
@ -166,19 +204,10 @@ def parse_args():
default=None, default=None,
help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)." help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
) )
parser.add_argument(
"--config",
type=str,
default=None,
help="Path to a JSON config file specifying the arguments to use. If resume_from is given, it is automatically inferred."
)
args = parser.parse_args() args = parser.parse_args()
if args.config is not None: if args.resume_from is not None:
with open(args.config, 'rt') as f: with open(Path(args.resume_from) / "resume.json", 'rt') as f:
args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)))
elif args.resume_from is not None:
with open(f"{args.resume_from}/resume.json", 'rt') as f:
args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"])) args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"]))
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@ -277,10 +306,10 @@ class TextualInversionDataset(Dataset):
self._length = self.num_images * repeats self._length = self.num_images * repeats
self.interpolation = { self.interpolation = {
"linear": PIL.Image.LINEAR, "linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.BILINEAR, "bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.BICUBIC, "bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.LANCZOS, "lanczos": PIL.Image.Resampling.LANCZOS,
}[interpolation] }[interpolation]
self.templates = templates self.templates = templates
@ -329,6 +358,16 @@ class TextualInversionDataset(Dataset):
return example return example
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def freeze_params(params): def freeze_params(params):
for param in params: for param in params:
param.requires_grad = False param.requires_grad = False
@ -337,7 +376,7 @@ def freeze_params(params):
def save_resume_file(basepath, args, extra = {}): def save_resume_file(basepath, args, extra = {}):
info = {"args": vars(args)} info = {"args": vars(args)}
info["args"].update(extra) info["args"].update(extra)
with open(f"{basepath}/resume.json", "w") as f: with open(Path(basepath) / "resume.json", "w") as f:
json.dump(info, f, indent=4) json.dump(info, f, indent=4)
@ -352,6 +391,9 @@ class Checkpointer:
placeholder_token_id, placeholder_token_id,
templates, templates,
output_dir, output_dir,
random_sample_batches,
sample_batch_size,
stable_sample_batches,
seed seed
): ):
self.accelerator = accelerator self.accelerator = accelerator
@ -362,14 +404,17 @@ class Checkpointer:
self.placeholder_token_id = placeholder_token_id self.placeholder_token_id = placeholder_token_id
self.templates = templates self.templates = templates
self.output_dir = output_dir self.output_dir = output_dir
self.random_sample_batches = random_sample_batches
self.sample_batch_size = sample_batch_size
self.stable_sample_batches = stable_sample_batches
self.seed = seed self.seed = seed
def checkpoint(self, step, text_encoder): def checkpoint(self, step, text_encoder, save_samples=True):
print("Saving checkpoint for step %d..." % step) print("Saving checkpoint for step %d..." % step)
with torch.autocast("cuda"): with torch.autocast("cuda"):
checkpoints_path = f"{self.output_dir}/checkpoints" checkpoints_path = self.output_dir / "checkpoints"
os.makedirs(checkpoints_path, exist_ok=True) checkpoints_path.mkdir(exist_ok=True, parents=True)
unwrapped = self.accelerator.unwrap_model(text_encoder) unwrapped = self.accelerator.unwrap_model(text_encoder)
@ -378,27 +423,94 @@ class Checkpointer:
learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
filename = f"learned_embeds_%s_%d.bin" % (slugify(self.placeholder_token), step) filename = f"learned_embeds_%s_%d.bin" % (slugify(self.placeholder_token), step)
torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") torch.save(learned_embeds_dict, checkpoints_path / filename)
torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") torch.save(learned_embeds_dict, checkpoints_path / "last.bin")
del unwrapped del unwrapped
return f"{checkpoints_path}/last.bin" return checkpoints_path / "last.bin"
def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
samples_path = self.output_dir / "samples"
samples_path.mkdir(exist_ok=True, parents=True)
checker = NoCheck()
with torch.autocast("cuda"):
unwrapped = self.accelerator.unwrap_model(text_encoder)
# Save a sample image
pipeline = StableDiffusionPipeline(
text_encoder=unwrapped,
vae=self.vae,
unet=self.unet,
tokenizer=self.tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=NoCheck(),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to('cuda')
pipeline.enable_attention_slicing()
if self.stable_sample_batches > 0:
stable_latents = torch.randn(
(self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
device=pipeline.device,
generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
)
stable_prompts = [choice.format(self.placeholder_token) for choice in (self.templates * self.sample_batch_size)[:self.sample_batch_size]]
# Generate and save stable samples
for i in range(0, self.stable_sample_batches):
samples = pipeline(
prompt=stable_prompts,
height=max(512, height),
latents=stable_latents,
width=max(512, width),
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"stable_sample_%d_%d_step_%d.png" % (i+1, idx+1, step)
im.save(samples_path / filename)
prompts = [choice.format(self.placeholder_token) for choice in random.choices(self.templates, k=self.sample_batch_size)]
# Generate and save random samples
for i in range(0, self.random_sample_batches):
samples = pipeline(
prompt=prompts,
height=max(512, height),
width=max(512, width),
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"step_%d_sample_%d_%d.png" % (step, i+1, idx+1)
im.save(samples_path / filename)
del im
del pipeline
del unwrapped
def main(): def main():
args = parse_args() args = parse_args()
global_step_offset = 0 global_step_offset = 0
if args.resume_from is not None: if args.resume_from is not None:
basepath = f"{args.resume_from}" basepath = Path(args.resume_from)
print("Resuming state from %s" % args.resume_from) print("Resuming state from %s" % args.resume_from)
with open(f"{basepath}/resume.json", 'r') as f: with open(basepath / "resume.json", 'r') as f:
state = json.load(f) state = json.load(f)
global_step_offset = state["args"]["global_step"] global_step_offset = state["args"]["global_step"]
print("We've trained %d steps so far" % global_step_offset) print("We've trained %d steps so far" % global_step_offset)
else: else:
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") basepath = Path(args.logging_dir) / slugify(args.placeholder_token)
basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}" basepath.mkdir(exist_ok=True, parents=True)
os.makedirs(basepath, exist_ok=True)
accelerator = Accelerator( accelerator = Accelerator(
@ -410,6 +522,23 @@ def main():
if args.seed is not None: if args.seed is not None:
set_seed(args.seed) set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer and add the placeholder token as a additional special token # Load the tokenizer and add the placeholder token as a additional special token
if args.tokenizer_name: if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@ -487,6 +616,9 @@ def main():
placeholder_token_id=placeholder_token_id, placeholder_token_id=placeholder_token_id,
templates=templates, templates=templates,
output_dir=basepath, output_dir=basepath,
sample_batch_size=args.sample_batch_size,
random_sample_batches=args.random_sample_batches,
stable_sample_batches=args.stable_sample_batches,
seed=args.seed seed=args.seed
) )
@ -518,7 +650,7 @@ def main():
learnable_property=args.learnable_property, learnable_property=args.learnable_property,
center_crop=args.center_crop, center_crop=args.center_crop,
set="train", set="train",
templates=templates templates=base_templates
) )
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
@ -628,6 +760,8 @@ def main():
"global_step": global_step + global_step_offset, "global_step": global_step + global_step_offset,
"resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin") "resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
}) })
checkpointer.save_samples(global_step + global_step_offset, text_encoder,
args.resolution, args.resolution, 7.5, 0.0, 25)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
@ -640,16 +774,36 @@ def main():
# Create the pipeline using using the trained modules and save it. # Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process: if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
#pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, f"{basepath}/learned_embeds.bin") torch.save(learned_embeds_dict, basepath / f"learned_embeds.bin")
if global_step % args.checkpoint_frequency != 0:
checkpointer.save_samples(global_step + global_step_offset, text_encoder,
args.resolution, args.resolution, 7.5, 0.0, 25)
print("Saving resume state") print("Saving resume state")
save_resume_file(basepath, args, { save_resume_file(basepath, args, {
"global_step": global_step + global_step_offset, "global_step": global_step + global_step_offset,
"resume_checkpoint": f"{basepath}/checkpoints/last.bin" "resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
}) })
if args.push_to_hub:
repo.push_to_hub(
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training() accelerator.end_training()
except KeyboardInterrupt: except KeyboardInterrupt:
@ -658,7 +812,7 @@ def main():
checkpointer.checkpoint(global_step + global_step_offset, text_encoder) checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
save_resume_file(basepath, args, { save_resume_file(basepath, args, {
"global_step": global_step + global_step_offset, "global_step": global_step + global_step_offset,
"resume_checkpoint": f"{basepath}/checkpoints/last.bin" "resume_checkpoint": str(Path(basepath) / "checkpoints" / "last.bin")
}) })
quit() quit()

View File

@ -65,21 +65,21 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
#use_RealESRGAN = 11 in toggles #use_RealESRGAN = 11 in toggles
if sampler_name == 'PLMS': if sampler_name == 'PLMS':
sampler = PLMSSampler(st.session_state["model"]) sampler = PLMSSampler(server_state["model"])
elif sampler_name == 'DDIM': elif sampler_name == 'DDIM':
sampler = DDIMSampler(st.session_state["model"]) sampler = DDIMSampler(server_state["model"])
elif sampler_name == 'k_dpm_2_a': elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2': elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') sampler = KDiffusionSampler(server_state["model"],'dpm_2')
elif sampler_name == 'k_euler_a': elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
elif sampler_name == 'k_euler': elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(st.session_state["model"],'euler') sampler = KDiffusionSampler(server_state["model"],'euler')
elif sampler_name == 'k_heun': elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(st.session_state["model"],'heun') sampler = KDiffusionSampler(server_state["model"],'heun')
elif sampler_name == 'k_lms': elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(st.session_state["model"],'lms') sampler = KDiffusionSampler(server_state["model"],'lms')
else: else:
raise Exception("Unknown sampler: " + sampler_name) raise Exception("Unknown sampler: " + sampler_name)
@ -160,18 +160,18 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
mask = (1 - mask) mask = (1 - mask)
mask = np.tile(mask, (4, 1, 1)) mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) mask = mask[None].transpose(0, 1, 2, 3)
mask = torch.from_numpy(mask).to(st.session_state["device"]) mask = torch.from_numpy(mask).to(server_state["device"])
if st.session_state['defaults'].general.optimized: if st.session_state['defaults'].general.optimized:
st.session_state.modelFS.to(st.session_state["device"] ) server_state["modelFS"].to(server_state["device"] )
init_image = 2. * image - 1. init_image = 2. * image - 1.
init_image = init_image.to(st.session_state["device"]) init_image = init_image.to(server_state["device"])
init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
if st.session_state['defaults'].general.optimized: if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6 mem = torch.cuda.memory_allocated()/1e6
st.session_state.modelFS.to("cpu") server_state["modelFS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem): while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1) time.sleep(1)
@ -208,7 +208,7 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
x0, z_mask = init_data x0, z_mask = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(server_state["device"] ))
# Obliterate masked image # Obliterate masked image
if z_mask is not None and obliterate: if z_mask is not None and obliterate:
@ -377,8 +377,8 @@ def layout():
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
custom_models_available() custom_models_available()
if st.session_state["CustomModel_available"]: if st.session_state["CustomModel_available"]:
st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"], st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model), index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. This option is only available if you have custom models \ help="Select the model you want to use. This option is only available if you have custom models \
on your 'models/custom' folder. The model name that will be shown here is the same as the name\ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
@ -442,13 +442,13 @@ def layout():
help="Save a file next to the image with informartion about the generation.") help="Save a file next to the image with informartion about the generation.")
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.") save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
if st.session_state["GFPGAN_available"]: if server_state["GFPGAN_available"]:
use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else: else:
use_GFPGAN = False use_GFPGAN = False
if st.session_state["RealESRGAN_available"]: if server_state["RealESRGAN_available"]:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN, st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
help="Uses the RealESRGAN model to upscale the images after the generation.\ help="Uses the RealESRGAN model to upscale the images after the generation.\
This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")

View File

@ -22,7 +22,7 @@ def sdConceptsBrowser(concepts, key=None):
return component_value return component_value
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True) @st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getConceptsFromPath(page, conceptPerPage, searchText=""): def getConceptsFromPath(page, conceptPerPage, searchText=""):
#print("getConceptsFromPath", "page:", page, "conceptPerPage:", conceptPerPage, "searchText:", searchText) #print("getConceptsFromPath", "page:", page, "conceptPerPage:", conceptPerPage, "searchText:", searchText)
# get the path where the concepts are stored # get the path where the concepts are stored
@ -97,7 +97,6 @@ def getConceptsFromPath(page, conceptPerPage, searchText=""):
#print("Results:", [c["name"] for c in concepts]) #print("Results:", [c["name"] for c in concepts])
return concepts return concepts
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True) @st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
def imageToBase64(image): def imageToBase64(image):
import io import io
@ -108,7 +107,7 @@ def imageToBase64(image):
return img_str return img_str
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True) @st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getTotalNumberOfConcepts(searchText=""): def getTotalNumberOfConcepts(searchText=""):
# get the path where the concepts are stored # get the path where the concepts are stored
path = os.path.join( path = os.path.join(
@ -138,7 +137,7 @@ def layout():
# Concept Library # Concept Library
with tab_library: with tab_library:
downloaded_concepts_count = getTotalNumberOfConcepts() downloaded_concepts_count = getTotalNumberOfConcepts()
concepts_per_page = 12 concepts_per_page = st.session_state["defaults"].concepts_library.concepts_per_page
if not "results" in st.session_state: if not "results" in st.session_state:
st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, "") st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, "")
@ -178,7 +177,7 @@ def layout():
# Previous page # Previous page
with _previous_page: with _previous_page:
if st.button("<", key="cl_previous_page"): if st.button("Previous", key="cl_previous_page"):
st.session_state["cl_current_page"] -= 1 st.session_state["cl_current_page"] -= 1
if st.session_state["cl_current_page"] <= 0: if st.session_state["cl_current_page"] <= 0:
st.session_state["cl_current_page"] = last_page st.session_state["cl_current_page"] = last_page
@ -190,7 +189,7 @@ def layout():
# Next page # Next page
with _next_page: with _next_page:
if st.button(">", key="cl_next_page"): if st.button("Next", key="cl_next_page"):
st.session_state["cl_current_page"] += 1 st.session_state["cl_current_page"] += 1
if st.session_state["cl_current_page"] > last_page: if st.session_state["cl_current_page"] > last_page:
st.session_state["cl_current_page"] = 1 st.session_state["cl_current_page"] = 1

View File

@ -4,6 +4,10 @@ from webui_streamlit import st
# streamlit imports # streamlit imports
from streamlit import StopException from streamlit import StopException
#streamlit components section
from streamlit_server_state import server_state, server_state_lock
#other imports #other imports
import warnings import warnings
@ -55,6 +59,7 @@ except:
# remove some annoying deprecation warnings that show every now and then. # remove some annoying deprecation warnings that show every now and then.
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init() mimetypes.init()
@ -153,15 +158,17 @@ def human_readable_size(size, decimal_places=3):
size /= 1024.0 size /= 1024.0
return f"{size:.{decimal_places}f}{unit}" return f"{size:.{decimal_places}f}{unit}"
@retry(tries=5)
def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus", def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus",
CustomModel_available=False, custom_model="Stable Diffusion v1.4"): CustomModel_available=False, custom_model="Stable Diffusion v1.4"):
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
print ("Loading models.") print ("Loading models.")
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("Loading models...") st.session_state["progress_bar_text"].text("Loading models...")
# Generate random run ID # Generate random run ID
# Used to link runs linked w/ continue_prev_run which is not yet implemented # Used to link runs linked w/ continue_prev_run which is not yet implemented
# Use URL and filesystem safe version just in case. # Use URL and filesystem safe version just in case.
@ -171,83 +178,95 @@ def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=Fals
# check what models we want to use and if the they are already loaded. # check what models we want to use and if the they are already loaded.
with server_state_lock["GFPGAN"]:
if use_GFPGAN: if use_GFPGAN:
if "GFPGAN" in st.session_state: if "GFPGAN" in server_state:
print("GFPGAN already loaded") print("GFPGAN already loaded")
else: else:
# Load GFPGAN # Load GFPGAN
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir): if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
try: try:
st.session_state["GFPGAN"] = load_GFPGAN() server_state["GFPGAN"] = load_GFPGAN()
print("Loaded GFPGAN") print("Loaded GFPGAN")
except Exception: except Exception:
import traceback import traceback
print("Error loading GFPGAN:", file=sys.stderr) print("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: else:
if "GFPGAN" in st.session_state: if "GFPGAN" in server_state:
del st.session_state["GFPGAN"] del server_state["GFPGAN"]
with server_state_lock["RealESRGAN"]:
if use_RealESRGAN: if use_RealESRGAN:
if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: if "RealESRGAN" in server_state and server_state["RealESRGAN"].model.name == RealESRGAN_model:
print("RealESRGAN already loaded") print("RealESRGAN already loaded")
else: else:
#Load RealESRGAN #Load RealESRGAN
try: try:
# We first remove the variable in case it has something there, # We first remove the variable in case it has something there,
# some errors can load the model incorrectly and leave things in memory. # some errors can load the model incorrectly and leave things in memory.
del st.session_state["RealESRGAN"] del server_state["RealESRGAN"]
except KeyError: except KeyError:
pass pass
if os.path.exists(st.session_state["defaults"].general.RealESRGAN_dir): if os.path.exists(st.session_state["defaults"].general.RealESRGAN_dir):
# st.session_state is used for keeping the models in memory across multiple pages or runs. # st.session_state is used for keeping the models in memory across multiple pages or runs.
st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) server_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model)
print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) print("Loaded RealESRGAN with model "+ server_state["RealESRGAN"].model.name)
else: else:
if "RealESRGAN" in st.session_state: if "RealESRGAN" in server_state:
del st.session_state["RealESRGAN"] del server_state["RealESRGAN"]
if "model" in st.session_state: with server_state_lock["model"], server_state_lock["modelCS"], server_state_lock["modelFS"], server_state_lock["loaded_model"]:
if "model" in st.session_state and st.session_state["loaded_model"] == custom_model:
if "model" in server_state:
if "model" in server_state and server_state["loaded_model"] == custom_model:
# TODO: check if the optimized mode was changed? # TODO: check if the optimized mode was changed?
if "pipe" in st.session_state:
del st.session_state.pipe
print("Model already loaded") print("Model already loaded")
return return
else: else:
try: try:
del st.session_state.model del server_state["model"]
del st.session_state.modelCS del server_state["modelCS"]
del st.session_state.modelFS del server_state["modelFS"]
del st.session_state.loaded_model del server_state["loaded_model"]
if "pipe" in st.session_state:
del st.session_state.pipe
except KeyError: except KeyError:
pass pass
# if the model from txt2vid is in memory we need to remove it to improve performance.
with server_state_lock["pipe"]:
if "pipe" in server_state:
del server_state["pipe"]
# At this point the model is either # At this point the model is either
# is not loaded yet or have been evicted: # is not loaded yet or have been evicted:
# load new model into memory # load new model into memory
st.session_state.custom_model = custom_model server_state["custom_model"] = custom_model
config, device, model, modelCS, modelFS = load_sd_model(custom_model) config, device, model, modelCS, modelFS = load_sd_model(custom_model)
st.session_state.device = device server_state["device"] = device
st.session_state.model = model server_state["model"] = model
st.session_state.modelCS = modelCS
st.session_state.modelFS = modelFS server_state["modelCS"] = modelCS
st.session_state.loaded_model = custom_model server_state["modelFS"] = modelFS
server_state["loaded_model"] = custom_model
#trying to disable multiprocessing as it makes it so streamlit cant stop when the
# model is loaded in memory and you need to kill the process sometimes.
try:
server_state["model"].args.use_multiprocessing_for_evaluation = False
except:
pass
if st.session_state.defaults.general.enable_attention_slicing: if st.session_state.defaults.general.enable_attention_slicing:
st.session_state.model.enable_attention_slicing() server_state["model"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage: if st.session_state.defaults.general.enable_minimal_memory_usage:
st.session_state.model.enable_minimal_memory_usage() server_state["model"].enable_minimal_memory_usage()
print("Model loaded.") print("Model loaded.")
@ -584,6 +603,7 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
def to_d(x, sigma, denoised): def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative.""" """Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim) return (x - denoised) / append_dims(sigma, x.ndim)
def linear_multistep_coeff(order, t, i, j): def linear_multistep_coeff(order, t, i, j):
if order - 1 > i: if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}') raise ValueError(f'Order {order} too high for step {i}')
@ -656,6 +676,7 @@ def torch_gc():
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
@retry(tries=5)
def load_GFPGAN(): def load_GFPGAN():
model_name = 'GFPGANv1.3' model_name = 'GFPGANv1.3'
model_path = os.path.join(st.session_state['defaults'].general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(st.session_state['defaults'].general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
@ -673,6 +694,7 @@ def load_GFPGAN():
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}")) instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
return instance return instance
@retry(tries=5)
def load_RealESRGAN(model_name: str): def load_RealESRGAN(model_name: str):
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
RealESRGAN_models = { RealESRGAN_models = {
@ -700,6 +722,7 @@ def load_RealESRGAN(model_name: str):
return instance return instance
# #
@retry(tries=5)
def load_LDSR(checking=False): def load_LDSR(checking=False):
model_name = 'model' model_name = 'model'
yaml_name = 'project' yaml_name = 'project'
@ -719,6 +742,8 @@ def load_LDSR(checking=False):
# #
LDSR = None LDSR = None
@retry(tries=5)
def try_loading_LDSR(model_name: str,checking=False): def try_loading_LDSR(model_name: str,checking=False):
global LDSR global LDSR
if os.path.exists(st.session_state['defaults'].general.LDSR_dir): if os.path.exists(st.session_state['defaults'].general.LDSR_dir):
@ -739,8 +764,10 @@ def try_loading_LDSR(model_name: str,checking=False):
# Loads Stable Diffusion model by name # Loads Stable Diffusion model by name
#@retry(tries=5)
def load_sd_model(model_name: str) -> [any, any, any, any, any]: def load_sd_model(model_name: str) -> [any, any, any, any, any]:
ckpt_path = st.session_state.defaults.general.default_model_path ckpt_path = st.session_state.defaults.general.default_model_path
if model_name != st.session_state.defaults.general.default_model: if model_name != st.session_state.defaults.general.default_model:
ckpt_path = os.path.join("models", "custom", f"{model_name}.ckpt") ckpt_path = os.path.join("models", "custom", f"{model_name}.ckpt")
@ -864,12 +891,12 @@ def generation_callback(img, i=0):
# It can probably be done in a better way for someone who knows what they're doing. I don't. # It can probably be done in a better way for someone who knows what they're doing. I don't.
#print (img,isinstance(img, torch.Tensor)) #print (img,isinstance(img, torch.Tensor))
if isinstance(img, torch.Tensor): if isinstance(img, torch.Tensor):
x_samples_ddim = (st.session_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else st.session_state.modelFS.to('cuda') x_samples_ddim = (server_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else server_state["modelFS"].to('cuda')
).decode_first_stage(img).to('cuda') ).decode_first_stage(img).to('cuda')
else: else:
# When using the k Diffusion samplers they return a dict instead of a tensor that look like this: # When using the k Diffusion samplers they return a dict instead of a tensor that look like this:
# {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}
x_samples_ddim = (st.session_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else st.session_state.modelFS.to('cuda') x_samples_ddim = (server_state["model"].to('cuda') if not st.session_state['defaults'].general.optimized else server_state["modelFS"].to('cuda')
).decode_first_stage(img["denoised"]).to('cuda') ).decode_first_stage(img["denoised"]).to('cuda')
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@ -954,6 +981,7 @@ def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
return v2 return v2
# #
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list): def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list):
"""Find the optimal update_preview_frequency value maximizing """Find the optimal update_preview_frequency value maximizing
performance while minimizing the time between updates.""" performance while minimizing the time between updates."""
@ -989,8 +1017,8 @@ def get_font(fontsize):
raise Exception(f"No usable font found (tried {', '.join(fonts)})") raise Exception(f"No usable font found (tried {', '.join(fonts)})")
def load_embeddings(fp): def load_embeddings(fp):
if fp is not None and hasattr(st.session_state["model"], "embedding_manager"): if fp is not None and hasattr(server_state["model"], "embedding_manager"):
st.session_state["model"].embedding_manager.load(fp['name']) server_state["model"].embedding_manager.load(fp['name'])
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
@ -1148,10 +1176,10 @@ def enable_minimal_memory_usage(model):
def check_prompt_length(prompt, comments): def check_prompt_length(prompt, comments):
"""this function tests if prompt is too long, and if so, adds a message to comments""" """this function tests if prompt is too long, and if so, adds a message to comments"""
tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer tokenizer = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.tokenizer
max_length = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.max_length max_length = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.max_length
info = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, info = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
return_overflowing_tokens=True, padding="max_length", return_tensors="pt") return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
ovf = info['overflowing_tokens'][0] ovf = info['overflowing_tokens'][0]
overflowing_count = ovf.shape[0] overflowing_count = ovf.shape[0]
@ -1169,17 +1197,17 @@ def custom_models_available():
# #
# Allow for custom models to be used instead of the default one, # Allow for custom models to be used instead of the default one,
# an example would be Waifu-Diffusion or any other fine tune of stable diffusion # an example would be Waifu-Diffusion or any other fine tune of stable diffusion
st.session_state["custom_models"]:sorted = [] server_state["custom_models"]:sorted = []
for root, dirs, files in os.walk(os.path.join("models", "custom")): for root, dirs, files in os.walk(os.path.join("models", "custom")):
for file in files: for file in files:
if os.path.splitext(file)[1] == '.ckpt': if os.path.splitext(file)[1] == '.ckpt':
st.session_state["custom_models"].append(os.path.splitext(file)[0]) server_state["custom_models"].append(os.path.splitext(file)[0])
if len(st.session_state["custom_models"]) > 0: if len(server_state["custom_models"]) > 0:
st.session_state["CustomModel_available"] = True st.session_state["CustomModel_available"] = True
st.session_state["custom_models"].append("Stable Diffusion v1.4") server_state["custom_models"].append("Stable Diffusion v1.4")
else: else:
st.session_state["CustomModel_available"] = False st.session_state["CustomModel_available"] = False
@ -1217,7 +1245,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt
target="txt2img" if init_img is None else "img2img", target="txt2img" if init_img is None else "img2img",
prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=st.session_state["loaded_model"]) seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=server_state["loaded_model"])
# Not yet any use for these, but they bloat up the files: # Not yet any use for these, but they bloat up the files:
# info_dict["init_img"] = init_img # info_dict["init_img"] = init_img
# info_dict["init_mask"] = init_mask # info_dict["init_mask"] = init_mask
@ -1386,8 +1414,8 @@ def process_images(
if prompt_tokens: if prompt_tokens:
# compviz # compviz
tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer tokenizer = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.tokenizer
text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer text_encoder = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).cond_stage_model.transformer
# diffusers # diffusers
#tokenizer = pipe.tokenizer #tokenizer = pipe.tokenizer
@ -1471,7 +1499,7 @@ def process_images(
output_images = [] output_images = []
grid_captions = [] grid_captions = []
stats = [] stats = []
with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not st.session_state['defaults'].general.optimized else nullcontext()): with torch.no_grad(), precision_scope("cuda"), (server_state["model"].ema_scope() if not st.session_state['defaults'].general.optimized else nullcontext()):
init_data = func_init() init_data = func_init()
tic = time.time() tic = time.time()
@ -1497,9 +1525,9 @@ def process_images(
print(prompt) print(prompt)
if st.session_state['defaults'].general.optimized: if st.session_state['defaults'].general.optimized:
st.session_state.modelCS.to(st.session_state['defaults'].general.gpu) server_state["modelCS"].to(st.session_state['defaults'].general.gpu)
uc = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(len(prompts) * [negprompt]) uc = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(len(prompts) * [negprompt])
if isinstance(prompts, tuple): if isinstance(prompts, tuple):
prompts = list(prompts) prompts = list(prompts)
@ -1513,23 +1541,23 @@ def process_images(
c = torch.zeros_like(uc) # i dont know if this is correct.. but it works c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
for i in range(0, len(weighted_subprompts)): for i in range(0, len(weighted_subprompts)):
# note if alpha negative, it functions same as torch.sub # note if alpha negative, it functions same as torch.sub
c = torch.add(c, (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) c = torch.add(c, (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
else: # just behave like usual else: # just behave like usual
c = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(prompts) c = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(prompts)
shape = [opt_C, height // opt_f, width // opt_f] shape = [opt_C, height // opt_f, width // opt_f]
if st.session_state['defaults'].general.optimized: if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6 mem = torch.cuda.memory_allocated()/1e6
st.session_state.modelCS.to("cpu") server_state["modelCS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem): while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1) time.sleep(1)
if noise_mode == 1 or noise_mode == 3: if noise_mode == 1 or noise_mode == 3:
# TODO params for find_noise_to_image # TODO params for find_noise_to_image
x = torch.cat(batch_size * [find_noise_for_image( x = torch.cat(batch_size * [find_noise_for_image(
st.session_state["model"], st.session_state["device"], server_state["model"], server_state["device"],
init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True, init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True,
generation_callback=generation_callback, generation_callback=generation_callback,
)], dim=0) )], dim=0)
@ -1551,9 +1579,9 @@ def process_images(
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
if st.session_state['defaults'].general.optimized: if st.session_state['defaults'].general.optimized:
st.session_state.modelFS.to(st.session_state['defaults'].general.gpu) server_state["modelFS"].to(st.session_state['defaults'].general.gpu)
x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(samples_ddim) x_samples_ddim = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
run_images = [] run_images = []
@ -1567,7 +1595,7 @@ def process_images(
full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
sanitized_prompt = sanitized_prompt[:220-len(full_path)] sanitized_prompt = sanitized_prompt[:200-len(full_path)]
sample_path_i = os.path.join(sample_path, sanitized_prompt) sample_path_i = os.path.join(sample_path, sanitized_prompt)
#print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
@ -1580,7 +1608,7 @@ def process_images(
full_path = os.path.join(os.getcwd(), sample_path) full_path = os.path.join(os.getcwd(), sample_path)
sample_path_i = sample_path sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i) base_count = get_next_sequence_number(sample_path_i)
filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)] #same as before
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
@ -1590,11 +1618,11 @@ def process_images(
st.session_state["preview_image"].image(image) st.session_state["preview_image"].image(image)
if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: if use_GFPGAN and server_state["GFPGAN"] is not None and not use_RealESRGAN:
st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(x_samples_ddim))) st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
#skip_save = True # #287 >_> #skip_save = True # #287 >_>
torch_gc() torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1] gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample) gfpgan_image = Image.fromarray(gfpgan_sample)
gfpgan_filename = original_filename + '-gfpgan' gfpgan_filename = original_filename + '-gfpgan'
@ -1602,7 +1630,7 @@ def process_images(
save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback,
uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta,
n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) n_iter, batch_size, i, denoising_strength, resize_mode, False, server_state["loaded_model"])
output_images.append(gfpgan_image) #287 output_images.append(gfpgan_image) #287
run_images.append(gfpgan_image) run_images.append(gfpgan_image)
@ -1610,16 +1638,16 @@ def process_images(
if simple_templating: if simple_templating:
grid_captions.append( captions[i] + "\ngfpgan" ) grid_captions.append( captions[i] + "\ngfpgan" )
elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: elif use_RealESRGAN and server_state["RealESRGAN"] is not None and not use_GFPGAN:
st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim))) st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
#skip_save = True # #287 >_> #skip_save = True # #287 >_>
torch_gc() torch_gc()
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name) #try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) output, img_mode = server_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
esrgan_filename = original_filename + '-esrgan4x' esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1] esrgan_sample = output[:,:,::-1]
esrgan_image = Image.fromarray(esrgan_sample) esrgan_image = Image.fromarray(esrgan_sample)
@ -1630,7 +1658,7 @@ def process_images(
save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, server_state["loaded_model"])
output_images.append(esrgan_image) #287 output_images.append(esrgan_image) #287
run_images.append(esrgan_image) run_images.append(esrgan_image)
@ -1638,25 +1666,25 @@ def process_images(
if simple_templating: if simple_templating:
grid_captions.append( captions[i] + "\nesrgan" ) grid_captions.append( captions[i] + "\nesrgan" )
elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: elif use_RealESRGAN and server_state["RealESRGAN"] is not None and use_GFPGAN and server_state["GFPGAN"] is not None:
st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim))) st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
#skip_save = True # #287 >_> #skip_save = True # #287 >_>
torch_gc() torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1] gfpgan_sample = restored_img[:,:,::-1]
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name) #try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) output, img_mode = server_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1] gfpgan_esrgan_sample = output[:,:,::-1]
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, server_state["loaded_model"])
output_images.append(gfpgan_esrgan_image) #287 output_images.append(gfpgan_esrgan_image) #287
run_images.append(gfpgan_esrgan_image) run_images.append(gfpgan_esrgan_image)
@ -1674,16 +1702,16 @@ def process_images(
init_img = init_img.convert('RGB') init_img = init_img.convert('RGB')
image = image.convert('RGB') image = image.convert('RGB')
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None: if use_RealESRGAN and server_state["RealESRGAN"] is not None:
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name) #try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8)) output, img_mode = server_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output) init_img = Image.fromarray(output)
init_img = init_img.convert('RGB') init_img = init_img.convert('RGB')
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8)) output, img_mode = server_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output) init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L') init_mask = init_mask.convert('L')
@ -1692,7 +1720,7 @@ def process_images(
if save_individual_images: if save_individual_images:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, st.session_state["loaded_model"]) save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, server_state["loaded_model"])
#if add_original_image or not simple_templating: #if add_original_image or not simple_templating:
#output_images.append(image) #output_images.append(image)
@ -1701,7 +1729,7 @@ def process_images(
if st.session_state['defaults'].general.optimized: if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6 mem = torch.cuda.memory_allocated()/1e6
st.session_state.modelFS.to("cpu") server_state["modelFS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem): while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1) time.sleep(1)
@ -1735,7 +1763,7 @@ def process_images(
output_images.insert(0, grid) output_images.insert(0, grid)
grid_count = get_next_sequence_number(outpath, 'grid-') grid_count = get_next_sequence_number(outpath, 'grid-')
grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:200-len(full_path)])}.{grid_ext}"
grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True)
toc = time.time() toc = time.time()
@ -1745,7 +1773,7 @@ def process_images(
info = f""" info = f"""
{prompt} {prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip() Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and server_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and server_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
stats = f''' stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
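The stats string converts bytes to MiB with negated floor division, which rounds up instead of down; a quick illustration of that idiom:

mem_max_used = 5_300_000_000              # bytes, illustrative value
print(mem_max_used // 1_048_576)          # 5054 (floor)
print(-(mem_max_used // -1_048_576))      # 5055 (ceiling, as used in the stats above)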
View File
@ -1,29 +1,49 @@
import gc
import inspect import inspect
import warnings import warnings
from tqdm.auto import tqdm
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch
from diffusers import ModelMixin
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import \ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
StableDiffusionSafetyChecker from diffusers import StableDiffusionPipelineOutput
from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, #from diffusers.safety_checker import StableDiffusionSafetyChecker
PNDMScheduler)
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
class StableDiffusionPipeline(DiffusionPipeline): class StableDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
def __init__( def __init__(
self, self,
vae: AutoencoderKL, vae: AutoencoderKL,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
@ -37,10 +57,45 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor into slices to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_minimal_memory_usage(self):
"""Moves only unet to fp16 and to CUDA, while keepping lighter models on CPUs"""
self.unet.to(torch.float16).to(torch.device("cuda"))
self.enable_attention_slicing(1)
torch.cuda.empty_cache()
gc.collect()
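Together the new helpers act as a memory/speed dial for the pipeline; a hedged usage sketch, assuming `pipe` is an already-constructed instance of the class above:

# `pipe` is assumed to be an instance of the StableDiffusionPipeline defined in this file.
pipe.enable_attention_slicing()       # "auto": attention computed in two steps to save some memory
pipe.enable_minimal_memory_usage()    # fp16 UNet on CUDA, lighter submodules stay on the CPU, slice_size=1
pipe.disable_attention_slicing()      # back to single-step attention once memory pressure is gone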
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt: Optional[Union[str, List[str]]] = None, prompt: Union[str, List[str]],
height: Optional[int] = 512, height: Optional[int] = 512,
width: Optional[int] = 512, width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50, num_inference_steps: Optional[int] = 50,
@ -48,38 +103,75 @@ class StableDiffusionPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
text_embeddings: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs, **kwargs,
): ):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a plain `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if "torch_device" in kwargs: if "torch_device" in kwargs:
device = kwargs.pop("torch_device") # device = kwargs.pop("torch_device")
warnings.warn( warnings.warn(
"`torch_device` is deprecated as an input argument to `__call__` and" "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
" will be removed in v0.3.0. Consider using `pipe.to(torch_device)`" " Consider using `pipe.to(torch_device)` instead."
" instead."
) )
# Set device as before (to be removed in 0.3.0) # Set device as before (to be removed in 0.3.0)
if device is None: # if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu" # device = "cuda" if torch.cuda.is_available() else "cpu"
self.to(device) # self.to(device)
if text_embeddings is None:
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif isinstance(prompt, list):
batch_size = len(prompt) batch_size = len(prompt)
else: else:
raise ValueError( raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError( raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
"`height` and `width` have to be divisible by 8 but are"
f" {height} and {width}."
)
# get prompt text embeddings # get prompt text embeddings
text_input = self.tokenizer( text_input = self.tokenizer(
@ -89,9 +181,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] text_embeddings = self.text_encoder(text_input.input_ids.to(self.text_encoder.device))[0].to(self.unet.device)
else:
batch_size = text_embeddings.shape[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@ -99,17 +189,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
# max_length = text_input.input_ids.shape[-1] max_length = text_input.input_ids.shape[-1]
max_length = 77 # self.tokenizer.model_max_length
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
[""] * batch_size, [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
padding="max_length", )
max_length=max_length, uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.text_encoder.device))[0].to(
return_tensors="pt", self.unet.device
) )
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(self.device)
)[0]
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
@ -117,25 +203,25 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None: if latents is None:
latents = torch.randn( latents = torch.randn(
latents_shape, latents_shape,
generator=generator, generator=generator,
device=self.device, device=latents_device,
) )
else: else:
if latents.shape != latents_shape: if latents.shape != latents_shape:
raise ValueError( raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
f"Unexpected latents shape, got {latents.shape}, expected"
f" {latents_shape}"
)
latents = latents.to(self.device) latents = latents.to(self.device)
# set timesteps # set timesteps
accepts_offset = "offset" in set( accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
inspect.signature(self.scheduler.set_timesteps).parameters.keys()
)
extra_set_kwargs = {} extra_set_kwargs = {}
if accepts_offset: if accepts_offset:
extra_set_kwargs["offset"] = 1 extra_set_kwargs["offset"] = 1
@ -150,18 +236,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1] # and should be between [0, 1]
accepts_eta = "eta" in set( accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {} extra_step_kwargs = {}
if accepts_eta: if accepts_eta:
extra_step_kwargs["eta"] = eta extra_step_kwargs["eta"] = eta
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = ( latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
if isinstance(self.scheduler, LMSDiscreteScheduler): if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i] sigma = self.scheduler.sigmas[i]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
@ -169,65 +251,43 @@ class StableDiffusionPipeline(DiffusionPipeline):
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings latent_model_input.to(self.unet.device), t.to(self.unet.device), encoder_hidden_states=text_embeddings
)["sample"] ).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler): if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step( latents = self.scheduler.step(
noise_pred, i, latents, **extra_step_kwargs noise_pred, i, latents.to(self.unet.device), **extra_step_kwargs
)["prev_sample"] ).prev_sample
else: else:
latents = self.scheduler.step( latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs noise_pred, t.to(self.unet.device), latents.to(self.unet.device), **extra_step_kwargs
)["prev_sample"] ).prev_sample
# scale and decode the image latents with vae # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample image = self.vae.decode(latents.to(self.vae.device)).sample
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.to(self.vae.device).to(self.vae.device).cpu().permute(0, 2, 3, 1).numpy()
safety_cheker_input = self.feature_extractor( # run safety checker
self.numpy_to_pil(image), return_tensors="pt" safety_cheker_input = (
).to(self.device) self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt")
image, has_nsfw_concept = self.safety_checker( .to(self.vae.device)
images=image, clip_input=safety_cheker_input.pixel_values .to(self.vae.dtype)
) )
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept} if not return_dict:
return (image, has_nsfw_concept)
def embed_text(self, text): return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
"""Helper to embed some text"""
with torch.autocast("cuda"):
text_input = self.tokenizer(
text,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
with torch.no_grad():
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
return embed
class NoCheck(ModelMixin):
"""Can be used in place of safety checker. Use responsibly and at your own risk."""
def __init__(self):
super().__init__()
self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))
def forward(self, images=None, **kwargs):
return images, [False]
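With the rewrite above, the pipeline takes a required prompt and, unless return_dict=False, returns a StableDiffusionPipelineOutput; a hedged usage sketch (prompt and settings are arbitrary, and `pipe` is assumed to be an already-loaded instance):

out = pipe("a watercolor fox in the snow",
           height=512, width=512,
           num_inference_steps=50, guidance_scale=7.5)
image = out.images[0]                    # PIL image when output_type="pil"
flagged = out.nsfw_content_detected[0]   # per-image safety-checker verdict

# Tuple form for callers that predate the output dataclass:
images, nsfw = pipe("a watercolor fox in the snow", return_dict=False)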
View File
@ -36,14 +36,14 @@ class plugin_info():
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
GFPGAN_available = True server_state["GFPGAN_available"] = True
else: else:
GFPGAN_available = False server_state["GFPGAN_available"] = False
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")): if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
RealESRGAN_available = True server_state["RealESRGAN_available"] = True
else: else:
RealESRGAN_available = False server_state["RealESRGAN_available"] = False
# #
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str,
@ -69,21 +69,21 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_na
#use_RealESRGAN = 8 in toggles #use_RealESRGAN = 8 in toggles
if sampler_name == 'PLMS': if sampler_name == 'PLMS':
sampler = PLMSSampler(st.session_state["model"]) sampler = PLMSSampler(server_state["model"])
elif sampler_name == 'DDIM': elif sampler_name == 'DDIM':
sampler = DDIMSampler(st.session_state["model"]) sampler = DDIMSampler(server_state["model"])
elif sampler_name == 'k_dpm_2_a': elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2': elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') sampler = KDiffusionSampler(server_state["model"],'dpm_2')
elif sampler_name == 'k_euler_a': elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
elif sampler_name == 'k_euler': elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(st.session_state["model"],'euler') sampler = KDiffusionSampler(server_state["model"],'euler')
elif sampler_name == 'k_heun': elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(st.session_state["model"],'heun') sampler = KDiffusionSampler(server_state["model"],'heun')
elif sampler_name == 'k_lms': elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(st.session_state["model"],'lms') sampler = KDiffusionSampler(server_state["model"],'lms')
else: else:
raise Exception("Unknown sampler: " + sampler_name) raise Exception("Unknown sampler: " + sampler_name)
@ -209,8 +209,8 @@ def layout():
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
custom_models_available() custom_models_available()
if st.session_state.CustomModel_available: if st.session_state.CustomModel_available:
st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models, server_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model), index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. This option is only available if you have custom models \ help="Select the model you want to use. This option is only available if you have custom models \
on your 'models/custom' folder. The model name that will be shown here is the same as the name\ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
@ -243,13 +243,13 @@ def layout():
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with informartion about the generation.") write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with informartion about the generation.")
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.") save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
if st.session_state["GFPGAN_available"]: if server_state["GFPGAN_available"]:
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
This greatly improves the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") This greatly improves the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else: else:
st.session_state["use_GFPGAN"] = False st.session_state["use_GFPGAN"] = False
if st.session_state["RealESRGAN_available"]: if server_state["RealESRGAN_available"]:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN, st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN,
help="Uses the RealESRGAN model to upscale the images after the generation.\ help="Uses the RealESRGAN model to upscale the images after the generation.\
This greatly improves the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") This greatly improves the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
View File
@ -6,6 +6,9 @@ from sd_utils import *
from streamlit import StopException from streamlit import StopException
from streamlit.elements import image as STImage from streamlit.elements import image as STImage
#streamlit components section
from streamlit_server_state import server_state, server_state_lock
#other imports #other imports
import os import os
@ -19,11 +22,12 @@ from io import BytesIO
import imageio import imageio
from slugify import slugify from slugify import slugify
# Temp imports
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
PNDMScheduler PNDMScheduler
# Temp imports
# end of imports # end of imports
#--------------------------------------------------------------------------------------------------------------- #---------------------------------------------------------------------------------------------------------------
@ -201,6 +205,66 @@ def diffuse(
return image2 return image2
# #
@st.experimental_singleton(show_spinner=False, suppress_st_warning=True)
def load_diffusers_model(weights_path,torch_device):
with server_state_lock["model"]:
if "model" in server_state:
del server_state["model"]
with server_state_lock["pipe"]:
try:
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
if st.session_state["weights_path"] != weights_path:
del st.session_state["weights_path"]
st.session_state["weights_path"] = weights_path
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
server_state["pipe"].unet.to(torch_device)
server_state["pipe"].vae.to(torch_device)
server_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
server_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
server_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
else:
print("Tx2Vid Model already Loaded")
except:
#del st.session_state["weights_path"]
#del server_state["pipe"]
st.session_state["weights_path"] = weights_path
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
server_state["pipe"].unet.to(torch_device)
server_state["pipe"].vae.to(torch_device)
server_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
server_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
server_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
#
def txt2vid( def txt2vid(
# -------------------------------------- # --------------------------------------
# args you probably want to change # args you probably want to change
@ -337,59 +401,12 @@ def txt2vid(
#print (st.session_state["weights_path"] != weights_path) #print (st.session_state["weights_path"] != weights_path)
try: load_diffusers_model(weights_path, torch_device)
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
if st.session_state["weights_path"] != weights_path:
del st.session_state["weights_path"]
st.session_state["weights_path"] = weights_path server_state["pipe"].scheduler = SCHEDULERS[scheduler]
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
st.session_state["pipe"].unet.to(torch_device) server_state["pipe"].use_multiprocessing_for_evaluation = False
st.session_state["pipe"].vae.to(torch_device) server_state["pipe"].use_multiprocessed_decoding = False
st.session_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
st.session_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
st.session_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
else:
print("Tx2Vid Model already Loaded")
except:
#del st.session_state["weights_path"]
#del st.session_state["pipe"]
st.session_state["weights_path"] = weights_path
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None
)
st.session_state["pipe"].unet.to(torch_device)
st.session_state["pipe"].vae.to(torch_device)
st.session_state["pipe"].text_encoder.to(torch_device)
if st.session_state.defaults.general.enable_attention_slicing:
st.session_state["pipe"].enable_attention_slicing()
if st.session_state.defaults.general.enable_minimal_memory_usage:
st.session_state["pipe"].enable_minimal_memory_usage()
print("Tx2Vid Model Loaded")
st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
if do_loop: if do_loop:
prompts = str([prompts, prompts]) prompts = str([prompts, prompts])
@ -399,8 +416,8 @@ def txt2vid(
#seeds.append(first_seed) #seeds.append(first_seed)
# get the conditional text embeddings based on the prompt # get the conditional text embeddings based on the prompt
text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt") text_input = server_state["pipe"].tokenizer(prompts, padding="max_length", max_length=server_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768] cond_embeddings = server_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
# #
if st.session_state.defaults.general.use_sd_concepts_library: if st.session_state.defaults.general.use_sd_concepts_library:
@ -434,7 +451,7 @@ def txt2vid(
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>") load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
# sample a source # sample a source
init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) init1 = torch.randn((1, server_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
# iterate the loop # iterate the loop
@ -451,7 +468,7 @@ def txt2vid(
st.session_state["current_frame"] = frame_index st.session_state["current_frame"] = frame_index
# sample the destination # sample the destination
init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) init2 = torch.randn((1, server_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
for i, t in enumerate(np.linspace(0, 1, num_steps)): for i, t in enumerate(np.linspace(0, 1, num_steps)):
start = timeit.default_timer() start = timeit.default_timer()
@ -465,9 +482,9 @@ def txt2vid(
init = slerp(gpu, float(t), init1, init2) init = slerp(gpu, float(t), init1, init2)
with autocast("cuda"): with autocast("cuda"):
image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta) image = diffuse(server_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]: if st.session_state["save_individual_images"] and not server_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
#im = Image.fromarray(image) #im = Image.fromarray(image)
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index) outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
image.save(outpath, quality=quality) image.save(outpath, quality=quality)
@ -481,13 +498,13 @@ def txt2vid(
# #
#try: #try:
#if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]: #if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if st.session_state["use_GFPGAN"] and st.session_state["GFPGAN"] is not None: if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None:
#print("Running GFPGAN on image ...") #print("Running GFPGAN on image ...")
st.session_state["progress_bar_text"].text("Running GFPGAN on image ...") st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
#skip_save = True # #287 >_> #skip_save = True # #287 >_>
torch_gc() torch_gc()
cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1] gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample) gfpgan_image = Image.fromarray(gfpgan_sample)
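Each frame interpolates from init1 to init2 in latent space before diffusing; the slerp helper itself sits outside this diff, but it typically looks like the hedged sketch below (a standard spherical-interpolation implementation, not necessarily the repo's exact code or signature):

import torch

def slerp_latents(t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995) -> torch.Tensor:
    """Spherical linear interpolation between two noise tensors, with t in [0, 1]."""
    dot = torch.sum(v0 * v1) / (torch.norm(v0) * torch.norm(v1))
    if torch.abs(dot) > dot_threshold:   # nearly parallel vectors: plain lerp is numerically safer
        return (1.0 - t) * v0 + t * v1
    theta = torch.acos(dot)
    sin_theta = torch.sin(theta)
    return (torch.sin((1.0 - t) * theta) / sin_theta) * v0 + (torch.sin(t * theta) / sin_theta) * v1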
@ -698,9 +715,9 @@ def layout():
st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.") st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
if GFPGAN_available: if GFPGAN_available:
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") server_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else: else:
st.session_state["use_GFPGAN"] = False server_state["use_GFPGAN"] = False
if RealESRGAN_available: if RealESRGAN_available:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN, st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
@ -726,16 +743,16 @@ def layout():
if generate_button: if generate_button:
#print("Loading models") #print("Loading models")
# load the models when we hit the generate button for the first time; they won't be reloaded after that, so don't worry. # load the models when we hit the generate button for the first time; they won't be reloaded after that, so don't worry.
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"]) #load_models(False, server_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
if st.session_state["use_GFPGAN"]: if server_state["use_GFPGAN"]:
if "GFPGAN" in st.session_state: if "GFPGAN" in st.session_state:
print("GFPGAN already loaded") print("GFPGAN already loaded")
else: else:
# Load GFPGAN # Load GFPGAN
if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir): if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
try: try:
st.session_state["GFPGAN"] = load_GFPGAN() server_state["GFPGAN"] = load_GFPGAN()
print("Loaded GFPGAN") print("Loaded GFPGAN")
except Exception: except Exception:
import traceback import traceback
@ -743,9 +760,9 @@ def layout():
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: else:
if "GFPGAN" in st.session_state: if "GFPGAN" in st.session_state:
del st.session_state["GFPGAN"] del server_state["GFPGAN"]
#try: try:
# run video generation # run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu, video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames), num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
@ -801,7 +818,7 @@ def layout():
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont] #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
#except (StopException, KeyError): except (StopException, KeyError):
#print(f"Received Streamlit StopException") print(f"Received Streamlit StopException")
View File
@ -7,6 +7,7 @@ import streamlit_nested_layout
#streamlit components section #streamlit components section
from st_on_hover_tabs import on_hover_tabs from st_on_hover_tabs import on_hover_tabs
from streamlit_server_state import server_state, server_state_lock
#other imports #other imports
@ -41,6 +42,7 @@ except:
# remove some annoying deprecation warnings that show every now and then. # remove some annoying deprecation warnings that show every now and then.
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# this should force GFPGAN and RealESRGAN onto the selected gpu as well # this should force GFPGAN and RealESRGAN onto the selected gpu as well
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
@ -70,15 +72,17 @@ def layout():
load_css(True, 'frontend/css/streamlit.main.css') load_css(True, 'frontend/css/streamlit.main.css')
# check if the models exist on their respective folders # check if the models exist on their respective folders
with server_state_lock["GFPGAN_available"]:
if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
st.session_state["GFPGAN_available"] = True server_state["GFPGAN_available"] = True
else: else:
st.session_state["GFPGAN_available"] = False server_state["GFPGAN_available"] = False
with server_state_lock["RealESRGAN_available"]:
if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")): if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
st.session_state["RealESRGAN_available"] = True server_state["RealESRGAN_available"] = True
else: else:
st.session_state["RealESRGAN_available"] = False server_state["RealESRGAN_available"] = False
## Allow for custom models to be used instead of the default one, ## Allow for custom models to be used instead of the default one,
## an example would be Waifu-Diffusion or any other fine tune of stable diffusion ## an example would be Waifu-Diffusion or any other fine tune of stable diffusion