Added basic image editor for the img2img tab. (#1613)

Added an image editor for img2img. For now it's pretty basic and has some
bugs, but it should do the job better than having to create a mask in
external software and import it into the UI every time.
Alejandro Gil 2022-10-28 12:41:21 -07:00 committed by GitHub
commit e9a4e46f66
7 changed files with 2028 additions and 910 deletions
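
A minimal sketch of the canvas-based mask flow this commit introduces, assuming streamlit-drawable-canvas is installed. The widget arguments mirror the st_canvas call in the diff below; the fixed width/height, the widget key and the preview call are placeholders added for illustration, not the app's exact code.

    import streamlit as st
    from PIL import Image
    from streamlit_drawable_canvas import st_canvas

    # Fixed size for this sketch; the real tab takes these from the width/height sliders.
    width, height = 512, 512

    uploaded = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
    image = Image.open(uploaded).convert("RGB").resize((width, height)) if uploaded else None

    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=st.slider("Stroke width: ", 1, 100, 50),
        stroke_color=st.color_picker("Stroke color hex: ", value="#EEEEEE"),
        background_image=image,   # draw the mask directly on top of the init image
        update_streamlit=True,
        width=width,
        height=height,
        drawing_mode="freedraw",
        key="mask_canvas",        # placeholder key for this sketch
    )

    # The widget returns an RGBA numpy array; the diff turns it into a PIL image
    # and resizes it before handing it to img2img as the mask.
    if canvas_result is not None and canvas_result.image_data is not None:
        mask = Image.fromarray(canvas_result.image_data).resize((width, height))
        st.image(mask, caption="Mask preview")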

View File

@@ -230,7 +230,8 @@ img2img:
 step: 0.01
 # 0: Keep masked area
 # 1: Regenerate only masked area
-mask_mode: 0
+mask_mode: 1
+noise_mode: "Matched Noise"
 mask_restore: False
 # 0: Just resize
 # 1: Crop and resize

View File

@@ -33,6 +33,7 @@ from ldm.models.diffusion.plms import PLMSSampler
 # streamlit components
 from custom_components import sygil_suggestions
+from streamlit_drawable_canvas import st_canvas
 # Temp imports
@@ -381,7 +382,7 @@ def layout():
 # creating the page layout using columns
-col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="medium")
+col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium")
 with col1_img2img_layout:
     # If we have custom models available on the "models/custom"
@@ -426,7 +427,7 @@ def layout():
 mask_expander = st.empty()
 with mask_expander.expander("Mask"):
     mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
-    mask_mode = st.selectbox("Mask Mode", mask_mode_list,
+    mask_mode = st.selectbox("Mask Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode,
         help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\
         \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
     )
@@ -434,15 +435,32 @@ def layout():
 noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
-noise_mode = st.selectbox(
-    "Noise Mode", noise_mode_list,
-    help=""
-)
-noise_mode = noise_mode_list.index(noise_mode)
+noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="")
+#noise_mode = noise_mode_list.index(noise_mode)
 find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value,
     min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
     step=st.session_state['defaults'].img2img.find_noise_steps.step)
+
+# Specify canvas parameters in application
+drawing_mode = st.selectbox(
+    "Drawing tool:",
+    (
+        "freedraw",
+        "transform",
+        #"line",
+        "rect",
+        "circle",
+        #"polygon",
+    ),
+)
+
+stroke_width = st.slider("Stroke width: ", 1, 100, 50)
+stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE")
+bg_color = st.color_picker("Background color hex: ", "#7B6E6E")
+display_toolbar = st.checkbox("Display toolbar", True)
+#realtime_update = st.checkbox("Update in realtime", True)
+
 with st.expander("Batch Options"):
     st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value,
         help="How many iterations or batches of images to generate in total.")
@@ -583,55 +601,63 @@ def layout():
 editor_image = st.empty()
 st.session_state["editor_image"] = editor_image

+st.form_submit_button("Refresh")
+
+#if "canvas" not in st.session_state:
+st.session_state["canvas"] = st.empty()
+
 masked_image_holder = st.empty()
 image_holder = st.empty()

-st.form_submit_button("Refresh")
-
 uploaded_images = st.file_uploader(
     "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
     help="Upload an image which will be used for the image to image generation.",
 )
 if uploaded_images:
-    image = Image.open(uploaded_images).convert('RGBA')
+    image = Image.open(uploaded_images).convert('RGB')
     new_img = image.resize((width, height))
-    image_holder.image(new_img)
-
-mask_holder = st.empty()
-
-uploaded_masks = st.file_uploader(
-    "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
-    help="Upload an mask image which will be used for masking the image to image generation.",
-)
-if uploaded_masks:
-    mask_expander.expander("Mask", expanded=True)
-    mask = Image.open(uploaded_masks)
-    if mask.mode == "RGBA":
-        mask = mask.convert('RGBA')
-        background = Image.new('RGBA', mask.size, (0, 0, 0))
-        mask = Image.alpha_composite(background, mask)
-    mask = mask.resize((width, height))
-    mask_holder.image(mask)
-
-if uploaded_images and uploaded_masks:
-    if mask_mode != 2:
-        final_img = new_img.copy()
-        alpha_layer = mask.convert('L')
-        strength = st.session_state["denoising_strength"]
-        if mask_mode == 0:
-            alpha_layer = ImageOps.invert(alpha_layer)
-            alpha_layer = alpha_layer.point(lambda a: a * strength)
-            alpha_layer = ImageOps.invert(alpha_layer)
-        elif mask_mode == 1:
-            alpha_layer = alpha_layer.point(lambda a: a * strength)
-            alpha_layer = ImageOps.invert(alpha_layer)
-        final_img.putalpha(alpha_layer)
-        with masked_image_holder.container():
-            st.text("Masked Image Preview")
-            st.image(final_img)
+    #image_holder.image(new_img)
+
+#mask_holder = st.empty()
+
+#uploaded_masks = st.file_uploader(
+#    "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
+#    help="Upload an mask image which will be used for masking the image to image generation.",
+#)
+#
+# Create a canvas component
+with st.session_state["canvas"]:
+    st.session_state["uploaded_masks"] = st_canvas(
+        fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
+        stroke_width=stroke_width,
+        stroke_color=stroke_color,
+        background_color=bg_color,
+        background_image=image if uploaded_images else None,
+        update_streamlit=True,
+        width=width,
+        height=height,
+        drawing_mode=drawing_mode,
+        initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None,
+        display_toolbar= display_toolbar,
+        key="full_app",
+    )
+
+#try:
+    ##print (type(st.session_state["uploaded_masks"]))
+    #if st.session_state["uploaded_masks"] != None:
+        #mask_expander.expander("Mask", expanded=True)
+        #mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
+        #st.image(mask)
+        #if mask.mode == "RGBA":
+            #mask = mask.convert('RGBA')
+            #background = Image.new('RGBA', mask.size, (0, 0, 0))
+            #mask = Image.alpha_composite(background, mask)
+        #mask = mask.resize((width, height))
+#except AttributeError:
+    #pass

 with col3_img2img_layout:
     result_tab = st.tabs(["Result"])
@@ -645,7 +671,6 @@ def layout():
 st.session_state["progress_bar_text"] = st.empty()
 st.session_state["progress_bar"] = st.empty()
 message = st.empty()

 #if uploaded_images:
@@ -666,14 +691,17 @@ def layout():
     CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])

 if uploaded_images:
-    image = Image.open(uploaded_images).convert('RGBA')
-    new_img = image.resize((width, height))
-    #img_array = np.array(image) # if you want to pass it to OpenCV
+    #image = Image.fromarray(image).convert('RGBA')
+    #new_img = image.resize((width, height))
+    ###img_array = np.array(image) # if you want to pass it to OpenCV
+    #image_holder.image(new_img)

     new_mask = None
-    if uploaded_masks:
-        mask = Image.open(uploaded_masks).convert('RGBA')
+
+    if st.session_state["uploaded_masks"]:
+        mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
         new_mask = mask.resize((width, height))
+        #masked_image_holder.image(new_mask)

     try:
         output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
                                                    mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],

View File

@@ -14,15 +14,13 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 # base webui import and utils.
-import collections.abc
 #from webui_streamlit import st
 import gfpgan
 import hydralit as st

 # streamlit imports
 from streamlit import StopException, StreamlitAPIException
-from streamlit.runtime.scriptrunner import script_run_context
+#from streamlit.runtime.scriptrunner import script_run_context

 #streamlit components section
 from streamlit_server_state import server_state, server_state_lock
@@ -35,7 +33,7 @@ import streamlit_nested_layout
 import warnings
 import json
-import base64
+import base64, cv2
 import os, sys, re, random, datetime, time, math, glob, toml
 import gc
 from PIL import Image, ImageFont, ImageDraw, ImageFilter
@@ -68,15 +66,31 @@ import piexif.helper
 from tqdm import trange
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import ismap
-from abc import ABC, abstractmethod
+#from abc import ABC, abstractmethod
 from typing import Dict, Union
 from io import BytesIO
 from packaging import version
 from uuid import uuid4
+from pathlib import Path
+from huggingface_hub import hf_hub_download

 #import librosa
-from logger import logger, set_logger_verbosity, quiesce_logger
+#from logger import logger, set_logger_verbosity, quiesce_logger
 #from loguru import logger
+from nataili.inference.compvis.img2img import img2img
+from nataili.model_manager import ModelManager
+from nataili.inference.compvis.txt2img import txt2img
+from nataili.util.cache import torch_gc
+from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
+
+try:
+    from realesrgan import RealESRGANer
+    from basicsr.archs.rrdbnet_arch import RRDBNet
+except ImportError as e:
+    logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
+        "pip install realesrgan")

 # Temp imports
 #from basicsr.utils.registry import ARCH_REGISTRY
@@ -84,14 +98,6 @@ from logger import logger, set_logger_verbosity, quiesce_logger
 # end of imports
 #---------------------------------------------------------------------------------------------------------------

-# we make a log file where we store the logs
-logger.add("logs/log_{time:MM-DD-YYYY!UTC}.log", rotation="8 MB", compression="zip", level='INFO') # Once the file is too old, it's rotated
-logger.add(sys.stderr, diagnose=True)
-logger.add(sys.stdout)
-logger.enable("")
-#
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
     from transformers import logging
@@ -112,6 +118,8 @@ mimetypes.add_type('application/javascript', '.js')
 opt_C = 4
 opt_f = 8

+# The model manager loads and unloads the SD models and has features to download them or find their location
+#model_manager = ModelManager()

 def load_configs():
     if not "defaults" in st.session_state:
@@ -269,6 +277,33 @@ def make_grid(n_items=5, n_cols=5):
     return cols

+def merge(file1, file2, out, weight):
+    alpha = (weight)/100
+    if not(file1.endswith(".ckpt")):
+        file1 += ".ckpt"
+    if not(file2.endswith(".ckpt")):
+        file2 += ".ckpt"
+    if not(out.endswith(".ckpt")):
+        out += ".ckpt"
+
+    #Load Models
+    model_0 = torch.load(file1)
+    model_1 = torch.load(file2)
+    theta_0 = model_0['state_dict']
+    theta_1 = model_1['state_dict']
+
+    for key in theta_0.keys():
+        if 'model' in key and key in theta_1:
+            theta_0[key] = (alpha) * theta_0[key] + (1-alpha) * theta_1[key]
+
+    logger.info("RUNNING...\n(STAGE 2)")
+
+    for key in theta_1.keys():
+        if 'model' in key and key not in theta_0:
+            theta_0[key] = theta_1[key]
+
+    torch.save(model_0, out)
+
 def human_readable_size(size, decimal_places=3):
     """Return a human readable size from bytes."""
     for unit in ['B','KB','MB','GB','TB']:
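
For reference, a hedged usage sketch of the checkpoint-merge helper added above. The file paths are placeholders; per the function body, weight is the percentage kept from the first model (alpha = weight/100).

    # Hypothetical paths; merge() appends ".ckpt" automatically if it is missing.
    # weight=50 blends the matching weights 50/50 between modelA and modelB.
    merge("models/custom/modelA", "models/custom/modelB", "models/custom/modelA-B-50", weight=50)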
@@ -282,6 +317,8 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
               CustomModel_available=False, custom_model="Stable Diffusion v1.5"):
     """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """

+    #model_manager.init()
+
     logger.info("Loading models.")

     if "progress_bar_text" in st.session_state:
@@ -1350,6 +1387,77 @@ def load_RealESRGAN(model_name: str):
     return server_state['RealESRGAN']

+#
+class RealESRGANModel(nn.Module):
+    def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
+        super().__init__()
+        try:
+            from basicsr.archs.rrdbnet_arch import RRDBNet
+            from realesrgan import RealESRGANer
+        except ImportError as e:
+            logger.error(
+                "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
+                "pip install realesrgan"
+            )
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        self.upsampler = RealESRGANer(
+            scale=4, model_path=model_path, model=model, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=not fp32
+        )
+
+    def forward(self, image, outscale=4, convert_to_pil=True):
+        """Upsample an image array or path.
+        Args:
+            image (Union[np.ndarray, str]): Either a np array or an image path. np array is assumed to be in RGB format,
+                and we convert it to BGR.
+            outscale (int, optional): Amount to upscale the image. Defaults to 4.
+            convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True.
+        Returns:
+            Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.
+        """
+        if isinstance(image, (str, Path)):
+            img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
+        else:
+            img = image
+            img = (img * 255).round().astype("uint8")
+            img = img[:, :, ::-1]
+
+        image, _ = self.upsampler.enhance(img, outscale=outscale)
+
+        if convert_to_pil:
+            image = Image.fromarray(image[:, :, ::-1])
+
+        return image
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path="nateraw/real-esrgan"):
+        """Initialize a pretrained Real-ESRGAN upsampler.
+        Args:
+            model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'.
+        Returns:
+            PipelineRealESRGAN: An instance of `PipelineRealESRGAN` instantiated from pretrained model.
+        """
+        # reuploaded form official ones mentioned here:
+        # https://github.com/xinntao/Real-ESRGAN
+        if Path(model_name_or_path).exists():
+            file = model_name_or_path
+        else:
+            file = hf_hub_download(model_name_or_path, "RealESRGAN_x4plus.pth")
+        return cls(file)
+
+    def upsample_imagefolder(self, in_dir, out_dir, suffix="out", outfile_ext=".png"):
+        in_dir, out_dir = Path(in_dir), Path(out_dir)
+        if not in_dir.exists():
+            raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")
+
+        out_dir.mkdir(exist_ok=True, parents=True)
+
+        image_paths = [x for x in in_dir.glob("*") if x.suffix.lower() in [".png", ".jpg", ".jpeg"]]
+        for image in image_paths:
+            im = self(str(image))
+            out_filepath = out_dir / (image.stem + suffix + outfile_ext)
+            im.save(out_filepath)
+
 #
 @retry(tries=5)
 def load_LDSR(model_name="model", config="project", checking=False):
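
A hedged usage sketch of the RealESRGANModel wrapper above. The input and output paths are placeholders; the default repo id and weight filename come from the class itself.

    # Load the x4 upscaler; downloads RealESRGAN_x4plus.pth from the default
    # "nateraw/real-esrgan" repo unless a local checkpoint path is passed.
    upsampler = RealESRGANModel.from_pretrained()

    # Upscale a single image (returns a PIL image by default).
    upscaled = upsampler("inputs/low_res.png", outscale=4)
    upscaled.save("outputs/low_res_x4.png")

    # Or upscale every image in a folder (paths are placeholders).
    upsampler.upsample_imagefolder("inputs/", "outputs/", suffix="_x4")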
@@ -1744,6 +1852,9 @@ def seed_to_int(s):
     if s is None or s == '':
         return random.randint(0, 2**32 - 1)

+    if ',' in s:
+        s = s.split(',')
+
     if type(s) is list:
         seed_list = []
         for seed in s:
@@ -1955,6 +2066,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt
     filename_i = os.path.join(sample_path_i, filename)

+    if "defaults" in st.session_state:
     if st.session_state['defaults'].general.save_metadata or write_info_files:
         # toggles differ for txt2img vs. img2img:
         offset = 0 if init_img is None else 2
@@ -2563,7 +2675,7 @@ def process_images(
     #output_images.append(image)
     #if simple_templating:
         #grid_captions.append( captions[i] )

+    if "defaults" in st.session_state:
     if st.session_state['defaults'].general.optimized:
         mem = torch.cuda.memory_allocated()/1e6
         server_state["modelFS"].to("cpu")

View File

@@ -10,7 +10,6 @@ import time
 import json
 import torch

-from diffusers import ModelMixin
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
@@ -22,59 +21,39 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from torch import nn

-from .upsampling import RealESRGANModel
+from sd_utils import RealESRGANModel

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-def get_spec_norm(wav, sr, n_mels=512, hop_length=704):
-    """Obtain maximum value for each time-frame in Mel Spectrogram,
-    and normalize between 0 and 1
-    Borrowed from lucid sonic dreams repo. In there, they programatically determine hop length
-    but I really didn't understand what was going on so I removed it and hard coded the output.
-    """
-    # Generate Mel Spectrogram
-    spec_raw = librosa.feature.melspectrogram(y=wav, sr=sr, n_mels=n_mels, hop_length=hop_length)
-    # Obtain maximum value per time-frame
-    spec_max = np.amax(spec_raw, axis=0)
-    # Normalize all values between 0 and 1
-    spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
-    return spec_norm
-
-def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=(1.0, 5.0)):
-    """Get the array that will be used to determine how much to interpolate between images.
-    Normally, this is just a linspace between 0 and 1 for the number of frames to generate. In this case,
-    we want to use the amplitude of the audio to determine how much to interpolate between images.
-    So, here we:
-        1. Load the audio file
-        2. Split the audio into harmonic and percussive components
-        3. Get the normalized amplitude of the percussive component, resized to the number of frames
-        4. Get the cumulative sum of the amplitude array
-        5. Normalize the cumulative sum between 0 and 1
-        6. Return the array
-    I honestly have no clue what I'm doing here. Suggestions welcome.
-    """
-    y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
-    wav_harmonic, wav_percussive = librosa.effects.hpss(y, margin=margin)
-
-    # Apparently n_mels is supposed to be input shape but I don't think it matters here?
-    frame_duration = int(sr / fps)
-    wav_norm = get_spec_norm(wav_percussive, sr, n_mels=512, hop_length=frame_duration)
-    amplitude_arr = np.resize(wav_norm, int(duration * fps))
-    T = np.cumsum(amplitude_arr)
-    T /= T[-1]
-    T[0] = 0.0
-    return T
+def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
+    y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
+
+    # librosa.stft hardcoded defaults...
+    # n_fft defaults to 2048
+    # hop length is win_length // 4
+    # win_length defaults to n_fft
+    D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)
+
+    # Extract percussive elements
+    D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
+    y_percussive = librosa.istft(D_percussive, length=len(y))
+
+    # Get normalized melspectrogram
+    spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
+    spec_max = np.amax(spec_raw, axis=0)
+    spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
+
+    # Resize cumsum of spec norm to our desired number of interpolation frames
+    x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
+    y_norm = np.cumsum(spec_norm)
+    y_norm /= y_norm[-1]
+    x_resize = np.linspace(0, y_norm.shape[-1], int(duration*fps))
+
+    T = np.interp(x_resize, x_norm, y_norm)
+
+    # Apply smoothing
+    return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth

 def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
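
The new smooth argument blends the audio-driven cumulative curve toward a plain linear ramp. A small, self-contained numpy illustration of that final line (the quadratic array below is only a stand-in for a real audio-derived curve):

    import numpy as np

    T = np.linspace(0.0, 1.0, 150) ** 2   # stand-in for the audio-derived cumulative curve
    smooth = 0.25                         # 0.0 keeps the audio curve, 1.0 gives plain linear interpolation
    T_smoothed = T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth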
@@ -130,7 +109,6 @@ def make_video_pyav(
             frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
             frames = frame if frames is None else torch.cat([frames, frame])
     else:
         frames = frames_or_frame_dir

     # TCHW -> THWC
@@ -208,6 +186,16 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if safety_checker is None:
+            logger.warn(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -251,6 +239,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         width: int = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -259,12 +249,13 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
         text_embeddings: Optional[torch.FloatTensor] = None,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
         Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+            prompt (`str` or `List[str]`, *optional*, defaults to `None`):
+                The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -278,6 +269,11 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -300,8 +296,10 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-            text_embeddings(`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings.
+            text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
+                Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
+                `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
+                the supplied `prompt`.
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
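
A hedged sketch of calling the pipeline with the newly documented arguments. The model id and prompts are placeholders; it assumes the pipeline was loaded with its own from_pretrained and a CUDA device is available.

    pipe = StableDiffusionWalkPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
    out = pipe(
        prompt="a watercolor painting of a fox",
        negative_prompt="blurry, low quality",   # new in this commit
        num_images_per_prompt=2,                 # new in this commit
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    images = out.images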
@@ -340,7 +338,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
             if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                 removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-                logger.warning(
+                print(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
                     f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                 )
@@ -349,21 +347,51 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         else:
             batch_size = text_embeddings.shape[0]

+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            # HACK - Not setting text_input_ids here when walking, so hard coding to max length of tokenizer
-            # TODO - Determine if this is OK to do
-            # max_length = text_input_ids.shape[-1]
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""]
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
             max_length = self.tokenizer.model_max_length
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
@@ -374,19 +402,20 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(latents_device)
+            latents = latents.to(self.device)

         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -431,12 +460,19 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+                self.device
+            )
             image, has_nsfw_concept = self.safety_checker(
                 images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
             )
+        else:
+            has_nsfw_concept = None

         if output_type == "pil":
             image = self.numpy_to_pil(image)
@@ -449,16 +485,9 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
     def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
         embeds_a = self.embed_text(prompt_a)
         embeds_b = self.embed_text(prompt_b)
-        latents_a = torch.randn(
-            noise_shape,
-            device=self.device,
-            generator=torch.Generator(device=self.device).manual_seed(seed_a),
-        )
-        latents_b = torch.randn(
-            noise_shape,
-            device=self.device,
-            generator=torch.Generator(device=self.device).manual_seed(seed_b),
-        )
+
+        latents_a = self.init_noise(seed_a, noise_shape)
+        latents_b = self.init_noise(seed_b, noise_shape)

         batch_idx = 0
         embeds_batch, noise_batch = None, None
@@ -477,7 +506,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 torch.cuda.empty_cache()
                 embeds_batch, noise_batch = None, None

-    def generate_interpolation_clip(
+    def make_clip_frames(
         self,
         prompt_a: str,
         prompt_b: str,
@@ -530,7 +559,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 eta=eta,
                 num_inference_steps=num_inference_steps,
                 output_type="pil" if not upsample else "numpy",
-            )["sample"]
+            )["images"]

             for image in outputs:
                 frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
@@ -557,6 +586,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         resume: Optional[bool] = False,
         audio_filepath: str = None,
         audio_start_sec: Optional[Union[int, float]] = None,
+        margin: Optional[float] = 1.0,
+        smooth: Optional[float] = 0.0,
     ):
         """Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
         video to interpolate to the intensity of the audio.
@@ -603,13 +634,17 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 Optional path to an audio file to influence the interpolation rate.
             audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
                 Global start time of the provided audio_filepath.
+            margin (Optional[float], *optional*, defaults to 1.0):
+                Margin from librosa hpss to use for audio interpolation.
+            smooth (Optional[float], *optional*, defaults to 0.0):
+                Smoothness of the audio interpolation. 1.0 means linear interpolation.

         This function will create sub directories for each prompt and seed pair.
         For example, if you provide the following prompts and seeds:

         ```
-        prompts = ['a', 'b', 'c']
+        prompts = ['a dog', 'a cat', 'a bird']
         seeds = [1, 2, 3]
         num_interpolation_steps = 5
         output_dir = 'output_dir'
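
A hedged sketch of an audio-reactive walk() call using the new margin/smooth options. Prompts, seeds and file paths are placeholders, and only the parameters visible in this diff and docstring are used; treat it as an illustration rather than the canonical call.

    pipe.walk(
        prompts=['a dog', 'a cat', 'a bird'],
        seeds=[1, 2, 3],
        num_interpolation_steps=[60, 60],   # frames between each consecutive prompt pair
        output_dir='output_dir',
        audio_filepath='music.mp3',         # placeholder audio file
        audio_start_sec=0,
        margin=1.0,                         # hpss margin used when isolating percussive energy
        smooth=0.2,                         # 0.0 = fully audio-driven, 1.0 = plain linear interpolation
    )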
@@ -722,7 +757,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
             audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
             audio_duration = num_step / fps

-            self.generate_interpolation_clip(
+            self.make_clip_frames(
                 prompt_a,
                 prompt_b,
                 seed_a,
@@ -742,7 +777,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                     offset=audio_offset,
                     duration=audio_duration,
                     fps=fps,
-                    margin=(1.0, 5.0),
+                    margin=margin,
+                    smooth=smooth,
                 )
                 if audio_filepath
                 else None,
@@ -783,6 +819,23 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
         return embed

+    def init_noise(self, seed, noise_shape):
+        """Helper to initialize noise"""
+        # randn does not exist on mps, so we create noise on CPU here and move it to the device after initialization
+        if self.device.type == "mps":
+            noise = torch.randn(
+                noise_shape,
+                device='cpu',
+                generator=torch.Generator(device='cpu').manual_seed(seed),
+            ).to(self.device)
+        else:
+            noise = torch.randn(
+                noise_shape,
+                device=self.device,
+                generator=torch.Generator(device=self.device).manual_seed(seed),
+            )
+        return noise
+
     @classmethod
     def from_pretrained(cls, *args, tiled=False, **kwargs):
         """Same as diffusers `from_pretrained` but with tiled option, which makes images tilable"""
@@ -799,15 +852,6 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
             patch_conv(padding_mode="circular")

-        return super().from_pretrained(*args, **kwargs)
-
-
-class NoCheck(ModelMixin):
-    """Can be used in place of safety checker. Use responsibly and at your own risk."""
-    def __init__(self):
-        super().__init__()
-        self.register_parameter(name="asdf", param=torch.nn.Parameter(torch.randn(3)))
-
-    def forward(self, images=None, **kwargs):
-        return images, [False]
+        pipeline = super().from_pretrained(*args, **kwargs)
+        pipeline.tiled = tiled
+        return pipeline

View File

@@ -410,7 +410,7 @@ def layout():
         sygil_suggestions.suggestion_area(placeholder)

         # creating the page layout using columns
-        col1, col2, col3 = st.columns([1,2,1], gap="large")
+        col1, col2, col3 = st.columns([2,5,2], gap="large")

         with col1:
             width = st.slider("Width:", min_value=st.session_state['defaults'].txt2img.width.min_value, max_value=st.session_state['defaults'].txt2img.width.max_value,

File diff suppressed because it is too large.

View File

@@ -125,8 +125,11 @@ def layout():
         {'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
         {'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
         {'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
-        #{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
+        {'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
         {'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
+        #{'id': 'Barfi/BaklavaJS', 'label': 'Barfi/BaklavaJS', 'icon': 'bi bi-diagram-3-fill'},
+        #{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
+        ]},
         {'id': 'Settings', 'label': 'Settings', 'icon': 'bi bi-gear-fill'},
         #{'icon': "fa-solid fa-radar",'label':"Dropdown1", 'submenu':[
         #    {'id':' subid11','icon': "fa fa-paperclip", 'label':"Sub-item 1"},{'id':'subid12','icon': "💀", 'label':"Sub-item 2"},{'id':'subid13','icon': "fa fa-database", 'label':"Sub-item 3"}]},
@@ -172,6 +175,10 @@ def layout():
         #horizontal_orientation=False,
         #override_theme={'txc_inactive': 'white','menu_background':'#111', 'stVerticalBlock': '#111','txc_active':'yellow','option_active':'blue'})

+    #
+    #if menu_id == "Home":
+        #st.info("Under Construction. :construction_worker:")
+
     if menu_id == "Stable Diffusion":
         # set the page url and title
         #st.experimental_set_query_params(page='stable-diffusion')
@@ -227,6 +234,11 @@ def layout():
             from APIServer import layout
             layout()

+        #elif menu_id == 'Barfi/BaklavaJS':
+            #set_page_title("Barfi/BaklavaJS - Stable Diffusion Playground")
+            #from barfi_baklavajs import layout
+            #layout()
+
         elif menu_id == 'Settings':
             set_page_title("Settings - Stable Diffusion Playground")