Added basic image editor for the img2img tab. (#1613)

Added an image editor for img2img, for now its pretty basic and has some
bugs but it should do the job better than having to create a mask using
an external software and then importing it to the UI every time.
This commit is contained in:
Alejandro Gil 2022-10-28 12:41:21 -07:00 committed by GitHub
commit e9a4e46f66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 2028 additions and 910 deletions

View File

@ -230,7 +230,8 @@ img2img:
step: 0.01
# 0: Keep masked area
# 1: Regenerate only masked area
mask_mode: 0
mask_mode: 1
noise_mode: "Matched Noise"
mask_restore: False
# 0: Just resize
# 1: Crop and resize

View File

@ -33,6 +33,7 @@ from ldm.models.diffusion.plms import PLMSSampler
# streamlit components
from custom_components import sygil_suggestions
from streamlit_drawable_canvas import st_canvas
# Temp imports
@ -381,7 +382,7 @@ def layout():
# creating the page layout using columns
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="medium")
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium")
with col1_img2img_layout:
# If we have custom models available on the "models/custom"
@ -426,7 +427,7 @@ def layout():
mask_expander = st.empty()
with mask_expander.expander("Mask"):
mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
mask_mode = st.selectbox("Mask Mode", mask_mode_list,
mask_mode = st.selectbox("Mask Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode,
help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\
\"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
)
@ -434,15 +435,32 @@ def layout():
noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
noise_mode = st.selectbox(
"Noise Mode", noise_mode_list,
help=""
)
noise_mode = noise_mode_list.index(noise_mode)
noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="")
#noise_mode = noise_mode_list.index(noise_mode)
find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value,
min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
step=st.session_state['defaults'].img2img.find_noise_steps.step)
# Specify canvas parameters in application
drawing_mode = st.selectbox(
"Drawing tool:",
(
"freedraw",
"transform",
#"line",
"rect",
"circle",
#"polygon",
),
)
stroke_width = st.slider("Stroke width: ", 1, 100, 50)
stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE")
bg_color = st.color_picker("Background color hex: ", "#7B6E6E")
display_toolbar = st.checkbox("Display toolbar", True)
#realtime_update = st.checkbox("Update in realtime", True)
with st.expander("Batch Options"):
st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value,
help="How many iterations or batches of images to generate in total.")
@ -583,55 +601,63 @@ def layout():
editor_image = st.empty()
st.session_state["editor_image"] = editor_image
st.form_submit_button("Refresh")
#if "canvas" not in st.session_state:
st.session_state["canvas"] = st.empty()
masked_image_holder = st.empty()
image_holder = st.empty()
st.form_submit_button("Refresh")
uploaded_images = st.file_uploader(
"Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
help="Upload an image which will be used for the image to image generation.",
)
if uploaded_images:
image = Image.open(uploaded_images).convert('RGBA')
image = Image.open(uploaded_images).convert('RGB')
new_img = image.resize((width, height))
image_holder.image(new_img)
#image_holder.image(new_img)
mask_holder = st.empty()
#mask_holder = st.empty()
uploaded_masks = st.file_uploader(
"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
help="Upload an mask image which will be used for masking the image to image generation.",
)
if uploaded_masks:
mask_expander.expander("Mask", expanded=True)
mask = Image.open(uploaded_masks)
if mask.mode == "RGBA":
mask = mask.convert('RGBA')
background = Image.new('RGBA', mask.size, (0, 0, 0))
mask = Image.alpha_composite(background, mask)
mask = mask.resize((width, height))
mask_holder.image(mask)
#uploaded_masks = st.file_uploader(
#"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
#help="Upload an mask image which will be used for masking the image to image generation.",
#)
if uploaded_images and uploaded_masks:
if mask_mode != 2:
final_img = new_img.copy()
alpha_layer = mask.convert('L')
strength = st.session_state["denoising_strength"]
if mask_mode == 0:
alpha_layer = ImageOps.invert(alpha_layer)
alpha_layer = alpha_layer.point(lambda a: a * strength)
alpha_layer = ImageOps.invert(alpha_layer)
elif mask_mode == 1:
alpha_layer = alpha_layer.point(lambda a: a * strength)
alpha_layer = ImageOps.invert(alpha_layer)
#
# Create a canvas component
with st.session_state["canvas"]:
st.session_state["uploaded_masks"] = st_canvas(
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color=bg_color,
background_image=image if uploaded_images else None,
update_streamlit=True,
width=width,
height=height,
drawing_mode=drawing_mode,
initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None,
display_toolbar= display_toolbar,
key="full_app",
)
final_img.putalpha(alpha_layer)
#try:
##print (type(st.session_state["uploaded_masks"]))
#if st.session_state["uploaded_masks"] != None:
#mask_expander.expander("Mask", expanded=True)
#mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
with masked_image_holder.container():
st.text("Masked Image Preview")
st.image(final_img)
#st.image(mask)
#if mask.mode == "RGBA":
#mask = mask.convert('RGBA')
#background = Image.new('RGBA', mask.size, (0, 0, 0))
#mask = Image.alpha_composite(background, mask)
#mask = mask.resize((width, height))
#except AttributeError:
#pass
with col3_img2img_layout:
result_tab = st.tabs(["Result"])
@ -645,7 +671,6 @@ def layout():
st.session_state["progress_bar_text"] = st.empty()
st.session_state["progress_bar"] = st.empty()
message = st.empty()
#if uploaded_images:
@ -666,14 +691,17 @@ def layout():
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
if uploaded_images:
image = Image.open(uploaded_images).convert('RGBA')
new_img = image.resize((width, height))
#img_array = np.array(image) # if you want to pass it to OpenCV
#image = Image.fromarray(image).convert('RGBA')
#new_img = image.resize((width, height))
###img_array = np.array(image) # if you want to pass it to OpenCV
#image_holder.image(new_img)
new_mask = None
if uploaded_masks:
mask = Image.open(uploaded_masks).convert('RGBA')
if st.session_state["uploaded_masks"]:
mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
new_mask = mask.resize((width, height))
#masked_image_holder.image(new_mask)
try:
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],

View File

@ -14,15 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
import collections.abc
#from webui_streamlit import st
import gfpgan
import hydralit as st
# streamlit imports
from streamlit import StopException, StreamlitAPIException
from streamlit.runtime.scriptrunner import script_run_context
#from streamlit.runtime.scriptrunner import script_run_context
#streamlit components section
from streamlit_server_state import server_state, server_state_lock
@ -35,7 +33,7 @@ import streamlit_nested_layout
import warnings
import json
import base64
import base64, cv2
import os, sys, re, random, datetime, time, math, glob, toml
import gc
from PIL import Image, ImageFont, ImageDraw, ImageFilter
@ -68,15 +66,31 @@ import piexif.helper
from tqdm import trange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
from abc import ABC, abstractmethod
#from abc import ABC, abstractmethod
from typing import Dict, Union
from io import BytesIO
from packaging import version
from uuid import uuid4
from pathlib import Path
from huggingface_hub import hf_hub_download
#import librosa
from logger import logger, set_logger_verbosity, quiesce_logger
#from logger import logger, set_logger_verbosity, quiesce_logger
#from loguru import logger
from nataili.inference.compvis.img2img import img2img
from nataili.model_manager import ModelManager
from nataili.inference.compvis.txt2img import txt2img
from nataili.util.cache import torch_gc
from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
try:
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
except ImportError as e:
logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
"pip install realesrgan")
# Temp imports
#from basicsr.utils.registry import ARCH_REGISTRY
@ -84,14 +98,6 @@ from logger import logger, set_logger_verbosity, quiesce_logger
# end of imports
#---------------------------------------------------------------------------------------------------------------
# we make a log file where we store the logs
logger.add("logs/log_{time:MM-DD-YYYY!UTC}.log", rotation="8 MB", compression="zip", level='INFO') # Once the file is too old, it's rotated
logger.add(sys.stderr, diagnose=True)
logger.add(sys.stdout)
logger.enable("")
#
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
@ -112,6 +118,8 @@ mimetypes.add_type('application/javascript', '.js')
opt_C = 4
opt_f = 8
# The model manager loads and unloads the SD models and has features to download them or find their location
#model_manager = ModelManager()
def load_configs():
if not "defaults" in st.session_state:
@ -269,6 +277,33 @@ def make_grid(n_items=5, n_cols=5):
return cols
def merge(file1, file2, out, weight):
alpha = (weight)/100
if not(file1.endswith(".ckpt")):
file1 += ".ckpt"
if not(file2.endswith(".ckpt")):
file2 += ".ckpt"
if not(out.endswith(".ckpt")):
out += ".ckpt"
#Load Models
model_0 = torch.load(file1)
model_1 = torch.load(file2)
theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict']
for key in theta_0.keys():
if 'model' in key and key in theta_1:
theta_0[key] = (alpha) * theta_0[key] + (1-alpha) * theta_1[key]
logger.info("RUNNING...\n(STAGE 2)")
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
torch.save(model_0, out)
def human_readable_size(size, decimal_places=3):
"""Return a human readable size from bytes."""
for unit in ['B','KB','MB','GB','TB']:
@ -282,6 +317,8 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
CustomModel_available=False, custom_model="Stable Diffusion v1.5"):
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
#model_manager.init()
logger.info("Loading models.")
if "progress_bar_text" in st.session_state:
@ -1350,6 +1387,77 @@ def load_RealESRGAN(model_name: str):
return server_state['RealESRGAN']
#
class RealESRGANModel(nn.Module):
def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
super().__init__()
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
except ImportError as e:
logger.error(
"You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
"pip install realesrgan"
)
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
self.upsampler = RealESRGANer(
scale=4, model_path=model_path, model=model, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=not fp32
)
def forward(self, image, outscale=4, convert_to_pil=True):
"""Upsample an image array or path.
Args:
image (Union[np.ndarray, str]): Either a np array or an image path. np array is assumed to be in RGB format,
and we convert it to BGR.
outscale (int, optional): Amount to upscale the image. Defaults to 4.
convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True.
Returns:
Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.
"""
if isinstance(image, (str, Path)):
img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
else:
img = image
img = (img * 255).round().astype("uint8")
img = img[:, :, ::-1]
image, _ = self.upsampler.enhance(img, outscale=outscale)
if convert_to_pil:
image = Image.fromarray(image[:, :, ::-1])
return image
@classmethod
def from_pretrained(cls, model_name_or_path="nateraw/real-esrgan"):
"""Initialize a pretrained Real-ESRGAN upsampler.
Args:
model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'.
Returns:
PipelineRealESRGAN: An instance of `PipelineRealESRGAN` instantiated from pretrained model.
"""
# reuploaded form official ones mentioned here:
# https://github.com/xinntao/Real-ESRGAN
if Path(model_name_or_path).exists():
file = model_name_or_path
else:
file = hf_hub_download(model_name_or_path, "RealESRGAN_x4plus.pth")
return cls(file)
def upsample_imagefolder(self, in_dir, out_dir, suffix="out", outfile_ext=".png"):
in_dir, out_dir = Path(in_dir), Path(out_dir)
if not in_dir.exists():
raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")
out_dir.mkdir(exist_ok=True, parents=True)
image_paths = [x for x in in_dir.glob("*") if x.suffix.lower() in [".png", ".jpg", ".jpeg"]]
for image in image_paths:
im = self(str(image))
out_filepath = out_dir / (image.stem + suffix + outfile_ext)
im.save(out_filepath)
#
@retry(tries=5)
def load_LDSR(model_name="model", config="project", checking=False):
@ -1744,6 +1852,9 @@ def seed_to_int(s):
if s is None or s == '':
return random.randint(0, 2**32 - 1)
if ',' in s:
s = s.split(',')
if type(s) is list:
seed_list = []
for seed in s:
@ -1955,41 +2066,42 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt
filename_i = os.path.join(sample_path_i, filename)
if st.session_state['defaults'].general.save_metadata or write_info_files:
# toggles differ for txt2img vs. img2img:
offset = 0 if init_img is None else 2
toggles = []
if prompt_matrix:
toggles.append(0)
if normalize_prompt_weights:
toggles.append(1)
if init_img is not None:
if uses_loopback:
toggles.append(2)
if uses_random_seed_loopback:
toggles.append(3)
if save_individual_images:
toggles.append(2 + offset)
if save_grid:
toggles.append(3 + offset)
if sort_samples:
toggles.append(4 + offset)
if write_info_files:
toggles.append(5 + offset)
if use_GFPGAN:
toggles.append(6 + offset)
metadata = \
dict(
target="txt2img" if init_img is None else "img2img",
prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=model_name)
# Not yet any use for these, but they bloat up the files:
# info_dict["init_img"] = init_img
# info_dict["init_mask"] = init_mask
if init_img is not None:
metadata["denoising_strength"] = str(denoising_strength)
metadata["resize_mode"] = resize_mode
if "defaults" in st.session_state:
if st.session_state['defaults'].general.save_metadata or write_info_files:
# toggles differ for txt2img vs. img2img:
offset = 0 if init_img is None else 2
toggles = []
if prompt_matrix:
toggles.append(0)
if normalize_prompt_weights:
toggles.append(1)
if init_img is not None:
if uses_loopback:
toggles.append(2)
if uses_random_seed_loopback:
toggles.append(3)
if save_individual_images:
toggles.append(2 + offset)
if save_grid:
toggles.append(3 + offset)
if sort_samples:
toggles.append(4 + offset)
if write_info_files:
toggles.append(5 + offset)
if use_GFPGAN:
toggles.append(6 + offset)
metadata = \
dict(
target="txt2img" if init_img is None else "img2img",
prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=model_name)
# Not yet any use for these, but they bloat up the files:
# info_dict["init_img"] = init_img
# info_dict["init_mask"] = init_mask
if init_img is not None:
metadata["denoising_strength"] = str(denoising_strength)
metadata["resize_mode"] = resize_mode
if write_info_files:
with open(f"{filename_i}.yaml", "w", encoding="utf8") as f:
@ -2563,12 +2675,12 @@ def process_images(
#output_images.append(image)
#if simple_templating:
#grid_captions.append( captions[i] )
if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6
server_state["modelFS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
if "defaults" in st.session_state:
if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6
server_state["modelFS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
if len(run_images) > 1:
preview_image = image_grid(run_images, n_iter)

View File

@ -10,7 +10,6 @@ import time
import json
import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
@ -22,59 +21,39 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from torch import nn
from .upsampling import RealESRGANModel
from sd_utils import RealESRGANModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_spec_norm(wav, sr, n_mels=512, hop_length=704):
"""Obtain maximum value for each time-frame in Mel Spectrogram,
and normalize between 0 and 1
def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
Borrowed from lucid sonic dreams repo. In there, they programatically determine hop length
but I really didn't understand what was going on so I removed it and hard coded the output.
"""
# librosa.stft hardcoded defaults...
# n_fft defaults to 2048
# hop length is win_length // 4
# win_length defaults to n_fft
D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)
# Generate Mel Spectrogram
spec_raw = librosa.feature.melspectrogram(y=wav, sr=sr, n_mels=n_mels, hop_length=hop_length)
# Extract percussive elements
D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
y_percussive = librosa.istft(D_percussive, length=len(y))
# Obtain maximum value per time-frame
# Get normalized melspectrogram
spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
spec_max = np.amax(spec_raw, axis=0)
# Normalize all values between 0 and 1
spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
return spec_norm
# Resize cumsum of spec norm to our desired number of interpolation frames
x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
y_norm = np.cumsum(spec_norm)
y_norm /= y_norm[-1]
x_resize = np.linspace(0, y_norm.shape[-1], int(duration*fps))
T = np.interp(x_resize, x_norm, y_norm)
def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=(1.0, 5.0)):
"""Get the array that will be used to determine how much to interpolate between images.
Normally, this is just a linspace between 0 and 1 for the number of frames to generate. In this case,
we want to use the amplitude of the audio to determine how much to interpolate between images.
So, here we:
1. Load the audio file
2. Split the audio into harmonic and percussive components
3. Get the normalized amplitude of the percussive component, resized to the number of frames
4. Get the cumulative sum of the amplitude array
5. Normalize the cumulative sum between 0 and 1
6. Return the array
I honestly have no clue what I'm doing here. Suggestions welcome.
"""
y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
wav_harmonic, wav_percussive = librosa.effects.hpss(y, margin=margin)
# Apparently n_mels is supposed to be input shape but I don't think it matters here?
frame_duration = int(sr / fps)
wav_norm = get_spec_norm(wav_percussive, sr, n_mels=512, hop_length=frame_duration)
amplitude_arr = np.resize(wav_norm, int(duration * fps))
T = np.cumsum(amplitude_arr)
T /= T[-1]
T[0] = 0.0
return T
# Apply smoothing
return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
@ -130,7 +109,6 @@ def make_video_pyav(
frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
frames = frame if frames is None else torch.cat([frames, frame])
else:
frames = frames_or_frame_dir
# TCHW -> THWC
@ -208,6 +186,16 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@ -251,6 +239,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
@ -259,12 +249,13 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
text_embeddings: Optional[torch.FloatTensor] = None,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
prompt (`str` or `List[str]`, *optional*, defaults to `None`):
The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
@ -278,6 +269,11 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@ -300,8 +296,10 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
text_embeddings(`torch.FloatTensor`, *optional*):
Pre-generated text embeddings.
text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
`prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
the supplied `prompt`.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
@ -340,7 +338,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
print(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
@ -349,21 +347,51 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
else:
batch_size = text_embeddings.shape[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
# HACK - Not setting text_input_ids here when walking, so hard coding to max length of tokenizer
# TODO - Determine if this is OK to do
# max_length = text_input_ids.shape[-1]
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = self.tokenizer.model_max_length
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
@ -374,19 +402,20 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=latents_device,
dtype=text_embeddings.dtype,
)
if self.device.type == "mps":
# randn does not exist on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(latents_device)
latents = latents.to(self.device)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
@ -431,12 +460,19 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)
@ -449,16 +485,9 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
embeds_a = self.embed_text(prompt_a)
embeds_b = self.embed_text(prompt_b)
latents_a = torch.randn(
noise_shape,
device=self.device,
generator=torch.Generator(device=self.device).manual_seed(seed_a),
)
latents_b = torch.randn(
noise_shape,
device=self.device,
generator=torch.Generator(device=self.device).manual_seed(seed_b),
)
latents_a = self.init_noise(seed_a, noise_shape)
latents_b = self.init_noise(seed_b, noise_shape)
batch_idx = 0
embeds_batch, noise_batch = None, None
@ -477,7 +506,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
torch.cuda.empty_cache()
embeds_batch, noise_batch = None, None
def generate_interpolation_clip(
def make_clip_frames(
self,
prompt_a: str,
prompt_b: str,
@ -530,7 +559,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
eta=eta,
num_inference_steps=num_inference_steps,
output_type="pil" if not upsample else "numpy",
)["sample"]
)["images"]
for image in outputs:
frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
@ -557,6 +586,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
resume: Optional[bool] = False,
audio_filepath: str = None,
audio_start_sec: Optional[Union[int, float]] = None,
margin: Optional[float] = 1.0,
smooth: Optional[float] = 0.0,
):
"""Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
video to interpolate to the intensity of the audio.
@ -603,13 +634,17 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
Optional path to an audio file to influence the interpolation rate.
audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
Global start time of the provided audio_filepath.
margin (Optional[float], *optional*, defaults to 1.0):
Margin from librosa hpss to use for audio interpolation.
smooth (Optional[float], *optional*, defaults to 0.0):
Smoothness of the audio interpolation. 1.0 means linear interpolation.
This function will create sub directories for each prompt and seed pair.
For example, if you provide the following prompts and seeds:
```
prompts = ['a', 'b', 'c']
prompts = ['a dog', 'a cat', 'a bird']
seeds = [1, 2, 3]
num_interpolation_steps = 5
output_dir = 'output_dir'
@ -722,7 +757,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
audio_duration = num_step / fps
self.generate_interpolation_clip(
self.make_clip_frames(
prompt_a,
prompt_b,
seed_a,
@ -742,7 +777,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
offset=audio_offset,
duration=audio_duration,
fps=fps,
margin=(1.0, 5.0),
margin=margin,
smooth=smooth,
)
if audio_filepath
else None,
@ -783,6 +819,23 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
return embed
def init_noise(self, seed, noise_shape):
"""Helper to initialize noise"""
# randn does not exist on mps, so we create noise on CPU here and move it to the device after initialization
if self.device.type == "mps":
noise = torch.randn(
noise_shape,
device='cpu',
generator=torch.Generator(device='cpu').manual_seed(seed),
).to(self.device)
else:
noise = torch.randn(
noise_shape,
device=self.device,
generator=torch.Generator(device=self.device).manual_seed(seed),
)
return noise
@classmethod
def from_pretrained(cls, *args, tiled=False, **kwargs):
"""Same as diffusers `from_pretrained` but with tiled option, which makes images tilable"""
@ -799,15 +852,6 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
patch_conv(padding_mode="circular")
return super().from_pretrained(*args, **kwargs)
class NoCheck(ModelMixin):
"""Can be used in place of safety checker. Use responsibly and at your own risk."""
def __init__(self):
super().__init__()
self.register_parameter(name="asdf", param=torch.nn.Parameter(torch.randn(3)))
def forward(self, images=None, **kwargs):
return images, [False]
pipeline = super().from_pretrained(*args, **kwargs)
pipeline.tiled = tiled
return pipeline

View File

@ -410,7 +410,7 @@ def layout():
sygil_suggestions.suggestion_area(placeholder)
# creating the page layout using columns
col1, col2, col3 = st.columns([1,2,1], gap="large")
col1, col2, col3 = st.columns([2,5,2], gap="large")
with col1:
width = st.slider("Width:", min_value=st.session_state['defaults'].txt2img.width.min_value, max_value=st.session_state['defaults'].txt2img.width.max_value,

File diff suppressed because it is too large Load Diff

View File

@ -125,8 +125,11 @@ def layout():
{'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
{'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
{'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
#{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
#{'id': 'Barfi/BaklavaJS', 'label': 'Barfi/BaklavaJS', 'icon': 'bi bi-diagram-3-fill'},
#{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
]},
{'id': 'Settings', 'label': 'Settings', 'icon': 'bi bi-gear-fill'},
#{'icon': "fa-solid fa-radar",'label':"Dropdown1", 'submenu':[
# {'id':' subid11','icon': "fa fa-paperclip", 'label':"Sub-item 1"},{'id':'subid12','icon': "💀", 'label':"Sub-item 2"},{'id':'subid13','icon': "fa fa-database", 'label':"Sub-item 3"}]},
@ -172,6 +175,10 @@ def layout():
#horizontal_orientation=False,
#override_theme={'txc_inactive': 'white','menu_background':'#111', 'stVerticalBlock': '#111','txc_active':'yellow','option_active':'blue'})
#
#if menu_id == "Home":
#st.info("Under Construction. :construction_worker:")
if menu_id == "Stable Diffusion":
# set the page url and title
#st.experimental_set_query_params(page='stable-diffusion')
@ -227,6 +234,11 @@ def layout():
from APIServer import layout
layout()
#elif menu_id == 'Barfi/BaklavaJS':
#set_page_title("Barfi/BaklavaJS - Stable Diffusion Playground")
#from barfi_baklavajs import layout
#layout()
elif menu_id == 'Settings':
set_page_title("Settings - Stable Diffusion Playground")