mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-14 22:13:41 +03:00
Added basic image editor for the img2img tab. (#1613)
Added an image editor for img2img. For now it's pretty basic and has some bugs, but it should do the job better than having to create a mask using external software and then importing it into the UI every time.
This commit is contained in:
commit
e9a4e46f66
@ -230,7 +230,8 @@ img2img:
|
||||
step: 0.01
|
||||
# 0: Keep masked area
|
||||
# 1: Regenerate only masked area
|
||||
mask_mode: 0
|
||||
mask_mode: 1
|
||||
noise_mode: "Matched Noise"
|
||||
mask_restore: False
|
||||
# 0: Just resize
|
||||
# 1: Crop and resize
|
||||
|
@ -33,6 +33,7 @@ from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
# streamlit components
|
||||
from custom_components import sygil_suggestions
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
|
||||
# Temp imports
|
||||
|
||||
@ -381,7 +382,7 @@ def layout():
|
||||
|
||||
|
||||
# creating the page layout using columns
|
||||
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="medium")
|
||||
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium")
|
||||
|
||||
with col1_img2img_layout:
|
||||
# If we have custom models available on the "models/custom"
|
||||
@ -426,7 +427,7 @@ def layout():
|
||||
mask_expander = st.empty()
|
||||
with mask_expander.expander("Mask"):
|
||||
mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
|
||||
mask_mode = st.selectbox("Mask Mode", mask_mode_list,
|
||||
mask_mode = st.selectbox("Mask Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode,
|
||||
help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\
|
||||
\"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
|
||||
)
|
||||
@ -434,15 +435,32 @@ def layout():
|
||||
|
||||
|
||||
noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
|
||||
noise_mode = st.selectbox(
|
||||
"Noise Mode", noise_mode_list,
|
||||
help=""
|
||||
)
|
||||
noise_mode = noise_mode_list.index(noise_mode)
|
||||
noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="")
|
||||
#noise_mode = noise_mode_list.index(noise_mode)
|
||||
find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value,
|
||||
min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
|
||||
step=st.session_state['defaults'].img2img.find_noise_steps.step)
|
||||
|
||||
# Specify canvas parameters in application
|
||||
drawing_mode = st.selectbox(
|
||||
"Drawing tool:",
|
||||
(
|
||||
"freedraw",
|
||||
"transform",
|
||||
#"line",
|
||||
"rect",
|
||||
"circle",
|
||||
#"polygon",
|
||||
),
|
||||
)
|
||||
|
||||
stroke_width = st.slider("Stroke width: ", 1, 100, 50)
|
||||
stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE")
|
||||
bg_color = st.color_picker("Background color hex: ", "#7B6E6E")
|
||||
|
||||
display_toolbar = st.checkbox("Display toolbar", True)
|
||||
#realtime_update = st.checkbox("Update in realtime", True)
|
||||
|
||||
with st.expander("Batch Options"):
|
||||
st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
@ -583,55 +601,63 @@ def layout():
|
||||
editor_image = st.empty()
|
||||
st.session_state["editor_image"] = editor_image
|
||||
|
||||
st.form_submit_button("Refresh")
|
||||
|
||||
#if "canvas" not in st.session_state:
|
||||
st.session_state["canvas"] = st.empty()
|
||||
|
||||
masked_image_holder = st.empty()
|
||||
image_holder = st.empty()
|
||||
|
||||
st.form_submit_button("Refresh")
|
||||
|
||||
uploaded_images = st.file_uploader(
|
||||
"Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
|
||||
help="Upload an image which will be used for the image to image generation.",
|
||||
)
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images).convert('RGBA')
|
||||
image = Image.open(uploaded_images).convert('RGB')
|
||||
new_img = image.resize((width, height))
|
||||
image_holder.image(new_img)
|
||||
#image_holder.image(new_img)
|
||||
|
||||
mask_holder = st.empty()
|
||||
#mask_holder = st.empty()
|
||||
|
||||
uploaded_masks = st.file_uploader(
|
||||
"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
|
||||
help="Upload an mask image which will be used for masking the image to image generation.",
|
||||
#uploaded_masks = st.file_uploader(
|
||||
#"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
|
||||
#help="Upload an mask image which will be used for masking the image to image generation.",
|
||||
#)
|
||||
|
||||
#
|
||||
# Create a canvas component
|
||||
with st.session_state["canvas"]:
|
||||
st.session_state["uploaded_masks"] = st_canvas(
|
||||
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
||||
stroke_width=stroke_width,
|
||||
stroke_color=stroke_color,
|
||||
background_color=bg_color,
|
||||
background_image=image if uploaded_images else None,
|
||||
update_streamlit=True,
|
||||
width=width,
|
||||
height=height,
|
||||
drawing_mode=drawing_mode,
|
||||
initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None,
|
||||
display_toolbar= display_toolbar,
|
||||
key="full_app",
|
||||
)
|
||||
if uploaded_masks:
|
||||
mask_expander.expander("Mask", expanded=True)
|
||||
mask = Image.open(uploaded_masks)
|
||||
if mask.mode == "RGBA":
|
||||
mask = mask.convert('RGBA')
|
||||
background = Image.new('RGBA', mask.size, (0, 0, 0))
|
||||
mask = Image.alpha_composite(background, mask)
|
||||
mask = mask.resize((width, height))
|
||||
mask_holder.image(mask)
|
||||
|
||||
if uploaded_images and uploaded_masks:
|
||||
if mask_mode != 2:
|
||||
final_img = new_img.copy()
|
||||
alpha_layer = mask.convert('L')
|
||||
strength = st.session_state["denoising_strength"]
|
||||
if mask_mode == 0:
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
alpha_layer = alpha_layer.point(lambda a: a * strength)
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
elif mask_mode == 1:
|
||||
alpha_layer = alpha_layer.point(lambda a: a * strength)
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
#try:
|
||||
##print (type(st.session_state["uploaded_masks"]))
|
||||
#if st.session_state["uploaded_masks"] != None:
|
||||
#mask_expander.expander("Mask", expanded=True)
|
||||
#mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
|
||||
|
||||
final_img.putalpha(alpha_layer)
|
||||
|
||||
with masked_image_holder.container():
|
||||
st.text("Masked Image Preview")
|
||||
st.image(final_img)
|
||||
#st.image(mask)
|
||||
|
||||
#if mask.mode == "RGBA":
|
||||
#mask = mask.convert('RGBA')
|
||||
#background = Image.new('RGBA', mask.size, (0, 0, 0))
|
||||
#mask = Image.alpha_composite(background, mask)
|
||||
#mask = mask.resize((width, height))
|
||||
#except AttributeError:
|
||||
#pass
|
||||
|
||||
with col3_img2img_layout:
|
||||
result_tab = st.tabs(["Result"])
|
||||
@ -645,7 +671,6 @@ def layout():
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
|
||||
message = st.empty()
|
||||
|
||||
#if uploaded_images:
|
||||
@ -666,14 +691,17 @@ def layout():
|
||||
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
|
||||
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images).convert('RGBA')
|
||||
new_img = image.resize((width, height))
|
||||
#img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
#image = Image.fromarray(image).convert('RGBA')
|
||||
#new_img = image.resize((width, height))
|
||||
###img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
#image_holder.image(new_img)
|
||||
new_mask = None
|
||||
if uploaded_masks:
|
||||
mask = Image.open(uploaded_masks).convert('RGBA')
|
||||
|
||||
if st.session_state["uploaded_masks"]:
|
||||
mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
|
||||
new_mask = mask.resize((width, height))
|
||||
|
||||
#masked_image_holder.image(new_mask)
|
||||
try:
|
||||
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
|
||||
mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
|
||||
|
@ -14,15 +14,13 @@
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
import collections.abc
|
||||
#from webui_streamlit import st
|
||||
import gfpgan
|
||||
import hydralit as st
|
||||
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException, StreamlitAPIException
|
||||
from streamlit.runtime.scriptrunner import script_run_context
|
||||
#from streamlit.runtime.scriptrunner import script_run_context
|
||||
|
||||
#streamlit components section
|
||||
from streamlit_server_state import server_state, server_state_lock
|
||||
@ -35,7 +33,7 @@ import streamlit_nested_layout
|
||||
import warnings
|
||||
import json
|
||||
|
||||
import base64
|
||||
import base64, cv2
|
||||
import os, sys, re, random, datetime, time, math, glob, toml
|
||||
import gc
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageFilter
|
||||
@ -68,15 +66,31 @@ import piexif.helper
|
||||
from tqdm import trange
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import ismap
|
||||
from abc import ABC, abstractmethod
|
||||
#from abc import ABC, abstractmethod
|
||||
from typing import Dict, Union
|
||||
from io import BytesIO
|
||||
from packaging import version
|
||||
from uuid import uuid4
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
#import librosa
|
||||
from logger import logger, set_logger_verbosity, quiesce_logger
|
||||
#from logger import logger, set_logger_verbosity, quiesce_logger
|
||||
#from loguru import logger
|
||||
|
||||
from nataili.inference.compvis.img2img import img2img
|
||||
from nataili.model_manager import ModelManager
|
||||
from nataili.inference.compvis.txt2img import txt2img
|
||||
from nataili.util.cache import torch_gc
|
||||
from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
|
||||
|
||||
try:
|
||||
from realesrgan import RealESRGANer
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
except ImportError as e:
|
||||
logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
|
||||
"pip install realesrgan")
|
||||
|
||||
# Temp imports
|
||||
#from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
@ -84,14 +98,6 @@ from logger import logger, set_logger_verbosity, quiesce_logger
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# we make a log file where we store the logs
|
||||
logger.add("logs/log_{time:MM-DD-YYYY!UTC}.log", rotation="8 MB", compression="zip", level='INFO') # Once the file is too old, it's rotated
|
||||
logger.add(sys.stderr, diagnose=True)
|
||||
logger.add(sys.stdout)
|
||||
logger.enable("")
|
||||
|
||||
#
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
@ -112,6 +118,8 @@ mimetypes.add_type('application/javascript', '.js')
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
|
||||
# The model manager loads and unloads the SD models and has features to download them or find their location
|
||||
#model_manager = ModelManager()
|
||||
|
||||
def load_configs():
|
||||
if not "defaults" in st.session_state:
|
||||
@ -269,6 +277,33 @@ def make_grid(n_items=5, n_cols=5):
|
||||
|
||||
return cols
|
||||
|
||||
|
||||
def _with_ckpt_suffix(path):
    """Return *path* with a ``.ckpt`` extension appended when it is missing."""
    return path if path.endswith(".ckpt") else path + ".ckpt"


def merge(file1, file2, out, weight):
    """Blend two checkpoints into a new one by linear interpolation.

    Every ``state_dict`` entry whose key contains ``"model"`` and exists in
    both checkpoints becomes ``alpha * first + (1 - alpha) * second`` with
    ``alpha = weight / 100``. Entries whose key contains ``"model"`` and that
    exist only in the second checkpoint are copied over unchanged. Everything
    else in the first checkpoint is kept as loaded.

    Args:
        file1: Path of the first checkpoint; ``.ckpt`` is appended if absent.
        file2: Path of the second checkpoint; ``.ckpt`` is appended if absent.
        out: Destination path; ``.ckpt`` is appended if absent.
        weight: Share of ``file1`` in the blend, expressed as a percentage.

    Raises:
        KeyError: If either checkpoint has no ``"state_dict"`` entry.
    """
    alpha = weight / 100
    file1, file2, out = (_with_ckpt_suffix(path) for path in (file1, file2, out))

    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code, so only merge checkpoints obtained from a trusted source.
    # map_location="cpu" keeps both models out of VRAM while merging and lets
    # checkpoints that were saved from a GPU be merged on a CPU-only machine.
    model_0 = torch.load(file1, map_location="cpu")
    model_1 = torch.load(file2, map_location="cpu")
    theta_0 = model_0['state_dict']
    theta_1 = model_1['state_dict']

    # Stage 1: interpolate the weights both checkpoints have in common.
    # Only existing keys are reassigned, so iterating the dict here is safe.
    for key in theta_0:
        if 'model' in key and key in theta_1:
            theta_0[key] = alpha * theta_0[key] + (1 - alpha) * theta_1[key]

    logger.info("RUNNING...\n(STAGE 2)")

    # Stage 2: carry over weights that only the second checkpoint provides.
    for key, value in theta_1.items():
        if 'model' in key and key not in theta_0:
            theta_0[key] = value

    # model_0 is saved whole, so its non-weight entries are preserved.
    torch.save(model_0, out)
|
||||
|
||||
|
||||
def human_readable_size(size, decimal_places=3):
|
||||
"""Return a human readable size from bytes."""
|
||||
for unit in ['B','KB','MB','GB','TB']:
|
||||
@ -282,6 +317,8 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
|
||||
CustomModel_available=False, custom_model="Stable Diffusion v1.5"):
|
||||
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
|
||||
|
||||
#model_manager.init()
|
||||
|
||||
logger.info("Loading models.")
|
||||
|
||||
if "progress_bar_text" in st.session_state:
|
||||
@ -1350,6 +1387,77 @@ def load_RealESRGAN(model_name: str):
|
||||
|
||||
return server_state['RealESRGAN']
|
||||
|
||||
#
|
||||
class RealESRGANModel(nn.Module):
    """``nn.Module`` wrapper around a 4x ``RealESRGANer`` upsampler."""

    def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
        """Build the RRDBNet backbone and the ``RealESRGANer`` that drives it.

        Args:
            model_path: Path to the Real-ESRGAN weights file.
            tile: Forwarded to ``RealESRGANer`` as ``tile``.
            tile_pad: Forwarded to ``RealESRGANer`` as ``tile_pad``.
            pre_pad: Forwarded to ``RealESRGANer`` as ``pre_pad``.
            fp32: When False (the default) the upsampler is created with ``half=True``.

        Raises:
            ImportError: If ``basicsr`` or ``realesrgan`` is not installed.
        """
        super().__init__()
        try:
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
        except ImportError:
            logger.error(
                "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
                "pip install realesrgan"
            )
            # Re-raise: carrying on would only fail a few lines below with an
            # unrelated "RRDBNet referenced before assignment" error.
            raise

        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        self.upsampler = RealESRGANer(
            scale=4, model_path=model_path, model=model, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=not fp32
        )

    def forward(self, image, outscale=4, convert_to_pil=True):
        """Upsample an image array or path.

        Args:
            image (Union[np.ndarray, str, Path]): Either a np array or an image path. np array is assumed to be
                in RGB format, and we convert it to BGR. It is multiplied by 255 before the uint8 cast, so it
                is presumably expected to hold floats in [0, 1] - TODO confirm against callers.
            outscale (int, optional): Amount to upscale the image. Defaults to 4.
            convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True.

        Returns:
            Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.

        Raises:
            OSError: If ``image`` is a path and OpenCV cannot read an image from it.
        """
        if isinstance(image, (str, Path)):
            # Convert to str first so that Path objects are accepted regardless
            # of the installed OpenCV version.
            img = cv2.imread(str(image), cv2.IMREAD_UNCHANGED)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise OSError(f"Could not read an image from {image}")
        else:
            img = image
            img = (img * 255).round().astype("uint8")
            # Reverse the channel axis (RGB -> BGR) for the upsampler.
            img = img[:, :, ::-1]

        image, _ = self.upsampler.enhance(img, outscale=outscale)

        if convert_to_pil:
            # Reverse the channel axis again so the PIL image is RGB.
            image = Image.fromarray(image[:, :, ::-1])

        return image

    @classmethod
    def from_pretrained(cls, model_name_or_path="nateraw/real-esrgan"):
        """Initialize a pretrained Real-ESRGAN upsampler.

        Args:
            model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'.

        Returns:
            RealESRGANModel: An instance of this class built from the pretrained weights.
        """
        # reuploaded from the official ones mentioned here:
        # https://github.com/xinntao/Real-ESRGAN
        if Path(model_name_or_path).exists():
            file = model_name_or_path
        else:
            file = hf_hub_download(model_name_or_path, "RealESRGAN_x4plus.pth")
        return cls(file)

    def upsample_imagefolder(self, in_dir, out_dir, suffix="out", outfile_ext=".png"):
        """Upsample every .png/.jpg/.jpeg file directly inside ``in_dir``.

        Each result is saved in ``out_dir`` (created if needed) as
        ``<original stem><suffix><outfile_ext>``.

        Args:
            in_dir: Directory holding the input images.
            out_dir: Directory that receives the upsampled images.
            suffix: Text appended to each input file's stem.
            outfile_ext: Extension, including the dot, of the written files.

        Raises:
            FileNotFoundError: If ``in_dir`` does not exist.
        """
        in_dir, out_dir = Path(in_dir), Path(out_dir)
        if not in_dir.exists():
            raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")

        out_dir.mkdir(exist_ok=True, parents=True)

        image_paths = [x for x in in_dir.glob("*") if x.suffix.lower() in [".png", ".jpg", ".jpeg"]]
        for image in image_paths:
            # Calling the module runs forward(), which returns a PIL image by default.
            im = self(str(image))
            out_filepath = out_dir / (image.stem + suffix + outfile_ext)
            im.save(out_filepath)
|
||||
|
||||
#
|
||||
@retry(tries=5)
|
||||
def load_LDSR(model_name="model", config="project", checking=False):
|
||||
@ -1744,6 +1852,9 @@ def seed_to_int(s):
|
||||
if s is None or s == '':
|
||||
return random.randint(0, 2**32 - 1)
|
||||
|
||||
if ',' in s:
|
||||
s = s.split(',')
|
||||
|
||||
if type(s) is list:
|
||||
seed_list = []
|
||||
for seed in s:
|
||||
@ -1955,6 +2066,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt
|
||||
|
||||
filename_i = os.path.join(sample_path_i, filename)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state['defaults'].general.save_metadata or write_info_files:
|
||||
# toggles differ for txt2img vs. img2img:
|
||||
offset = 0 if init_img is None else 2
|
||||
@ -2563,7 +2675,7 @@ def process_images(
|
||||
#output_images.append(image)
|
||||
#if simple_templating:
|
||||
#grid_captions.append( captions[i] )
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
server_state["modelFS"].to("cpu")
|
||||
|
@ -10,7 +10,6 @@ import time
|
||||
import json
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
@ -22,59 +21,39 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from torch import nn
|
||||
|
||||
from .upsampling import RealESRGANModel
|
||||
|
||||
from sd_utils import RealESRGANModel
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_spec_norm(wav, sr, n_mels=512, hop_length=704):
|
||||
"""Obtain maximum value for each time-frame in Mel Spectrogram,
|
||||
and normalize between 0 and 1
|
||||
def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
|
||||
y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
|
||||
|
||||
Borrowed from lucid sonic dreams repo. In there, they programatically determine hop length
|
||||
but I really didn't understand what was going on so I removed it and hard coded the output.
|
||||
"""
|
||||
# librosa.stft hardcoded defaults...
|
||||
# n_fft defaults to 2048
|
||||
# hop length is win_length // 4
|
||||
# win_length defaults to n_fft
|
||||
D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)
|
||||
|
||||
# Generate Mel Spectrogram
|
||||
spec_raw = librosa.feature.melspectrogram(y=wav, sr=sr, n_mels=n_mels, hop_length=hop_length)
|
||||
# Extract percussive elements
|
||||
D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
|
||||
y_percussive = librosa.istft(D_percussive, length=len(y))
|
||||
|
||||
# Obtain maximum value per time-frame
|
||||
# Get normalized melspectrogram
|
||||
spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
|
||||
spec_max = np.amax(spec_raw, axis=0)
|
||||
|
||||
# Normalize all values between 0 and 1
|
||||
spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
|
||||
|
||||
return spec_norm
|
||||
# Resize cumsum of spec norm to our desired number of interpolation frames
|
||||
x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
|
||||
y_norm = np.cumsum(spec_norm)
|
||||
y_norm /= y_norm[-1]
|
||||
x_resize = np.linspace(0, y_norm.shape[-1], int(duration*fps))
|
||||
|
||||
T = np.interp(x_resize, x_norm, y_norm)
|
||||
|
||||
def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=(1.0, 5.0)):
|
||||
"""Get the array that will be used to determine how much to interpolate between images.
|
||||
|
||||
Normally, this is just a linspace between 0 and 1 for the number of frames to generate. In this case,
|
||||
we want to use the amplitude of the audio to determine how much to interpolate between images.
|
||||
|
||||
So, here we:
|
||||
1. Load the audio file
|
||||
2. Split the audio into harmonic and percussive components
|
||||
3. Get the normalized amplitude of the percussive component, resized to the number of frames
|
||||
4. Get the cumulative sum of the amplitude array
|
||||
5. Normalize the cumulative sum between 0 and 1
|
||||
6. Return the array
|
||||
|
||||
I honestly have no clue what I'm doing here. Suggestions welcome.
|
||||
"""
|
||||
y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
|
||||
wav_harmonic, wav_percussive = librosa.effects.hpss(y, margin=margin)
|
||||
|
||||
# Apparently n_mels is supposed to be input shape but I don't think it matters here?
|
||||
frame_duration = int(sr / fps)
|
||||
wav_norm = get_spec_norm(wav_percussive, sr, n_mels=512, hop_length=frame_duration)
|
||||
amplitude_arr = np.resize(wav_norm, int(duration * fps))
|
||||
T = np.cumsum(amplitude_arr)
|
||||
T /= T[-1]
|
||||
T[0] = 0.0
|
||||
return T
|
||||
# Apply smoothing
|
||||
return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth
|
||||
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
@ -130,7 +109,6 @@ def make_video_pyav(
|
||||
frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
|
||||
frames = frame if frames is None else torch.cat([frames, frame])
|
||||
else:
|
||||
|
||||
frames = frames_or_frame_dir
|
||||
|
||||
# TCHW -> THWC
|
||||
@ -208,6 +186,16 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@ -251,6 +239,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
@ -259,12 +249,13 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
text_embeddings: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
prompt (`str` or `List[str]`, *optional*, defaults to `None`):
|
||||
The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
@ -278,6 +269,11 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
@ -300,8 +296,10 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
text_embeddings(`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings.
|
||||
text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
|
||||
`prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
|
||||
the supplied `prompt`.
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
@ -340,7 +338,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
print(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
@ -349,21 +347,51 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
else:
|
||||
batch_size = text_embeddings.shape[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
# HACK - Not setting text_input_ids here when walking, so hard coding to max length of tokenizer
|
||||
# TODO - Determine if this is OK to do
|
||||
# max_length = text_input_ids.shape[-1]
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = self.tokenizer.model_max_length
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
@ -374,19 +402,20 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
latents_device = "cpu" if self.device.type == "mps" else self.device
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
generator=generator,
|
||||
device=latents_device,
|
||||
dtype=text_embeddings.dtype,
|
||||
if self.device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(latents_device)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
@ -431,12 +460,19 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
@ -449,16 +485,9 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
|
||||
embeds_a = self.embed_text(prompt_a)
|
||||
embeds_b = self.embed_text(prompt_b)
|
||||
latents_a = torch.randn(
|
||||
noise_shape,
|
||||
device=self.device,
|
||||
generator=torch.Generator(device=self.device).manual_seed(seed_a),
|
||||
)
|
||||
latents_b = torch.randn(
|
||||
noise_shape,
|
||||
device=self.device,
|
||||
generator=torch.Generator(device=self.device).manual_seed(seed_b),
|
||||
)
|
||||
|
||||
latents_a = self.init_noise(seed_a, noise_shape)
|
||||
latents_b = self.init_noise(seed_b, noise_shape)
|
||||
|
||||
batch_idx = 0
|
||||
embeds_batch, noise_batch = None, None
|
||||
@ -477,7 +506,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
torch.cuda.empty_cache()
|
||||
embeds_batch, noise_batch = None, None
|
||||
|
||||
def generate_interpolation_clip(
|
||||
def make_clip_frames(
|
||||
self,
|
||||
prompt_a: str,
|
||||
prompt_b: str,
|
||||
@ -530,7 +559,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
eta=eta,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type="pil" if not upsample else "numpy",
|
||||
)["sample"]
|
||||
)["images"]
|
||||
|
||||
for image in outputs:
|
||||
frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
|
||||
@ -557,6 +586,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
resume: Optional[bool] = False,
|
||||
audio_filepath: str = None,
|
||||
audio_start_sec: Optional[Union[int, float]] = None,
|
||||
margin: Optional[float] = 1.0,
|
||||
smooth: Optional[float] = 0.0,
|
||||
):
|
||||
"""Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
|
||||
video to interpolate to the intensity of the audio.
|
||||
@ -603,13 +634,17 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
Optional path to an audio file to influence the interpolation rate.
|
||||
audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
|
||||
Global start time of the provided audio_filepath.
|
||||
margin (Optional[float], *optional*, defaults to 1.0):
|
||||
Margin from librosa hpss to use for audio interpolation.
|
||||
smooth (Optional[float], *optional*, defaults to 0.0):
|
||||
Smoothness of the audio interpolation. 1.0 means linear interpolation.
|
||||
|
||||
This function will create sub directories for each prompt and seed pair.
|
||||
|
||||
For example, if you provide the following prompts and seeds:
|
||||
|
||||
```
|
||||
prompts = ['a', 'b', 'c']
|
||||
prompts = ['a dog', 'a cat', 'a bird']
|
||||
seeds = [1, 2, 3]
|
||||
num_interpolation_steps = 5
|
||||
output_dir = 'output_dir'
|
||||
@ -722,7 +757,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
|
||||
audio_duration = num_step / fps
|
||||
|
||||
self.generate_interpolation_clip(
|
||||
self.make_clip_frames(
|
||||
prompt_a,
|
||||
prompt_b,
|
||||
seed_a,
|
||||
@ -742,7 +777,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
offset=audio_offset,
|
||||
duration=audio_duration,
|
||||
fps=fps,
|
||||
margin=(1.0, 5.0),
|
||||
margin=margin,
|
||||
smooth=smooth,
|
||||
)
|
||||
if audio_filepath
|
||||
else None,
|
||||
@ -783,6 +819,23 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
return embed
|
||||
|
||||
def init_noise(self, seed, noise_shape):
    """Return seeded Gaussian noise of shape ``noise_shape`` on ``self.device``.

    Args:
        seed: Value passed to ``torch.Generator.manual_seed`` before sampling.
        noise_shape: Shape handed to ``torch.randn``.

    Returns:
        The tensor produced by ``torch.randn``, located on ``self.device``.
    """
    # randn does not exist on mps, so on that backend the noise is sampled on
    # the CPU (with a CPU generator) and moved to the device afterwards.
    # Every other device samples directly where the tensor will live.
    sample_device = "cpu" if self.device.type == "mps" else self.device
    generator = torch.Generator(device=sample_device).manual_seed(seed)
    noise = torch.randn(noise_shape, device=sample_device, generator=generator)
    # .to() hands back the same tensor when it is already on self.device.
    return noise.to(self.device)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, tiled=False, **kwargs):
|
||||
"""Same as diffusers `from_pretrained` but with tiled option, which makes images tilable"""
|
||||
@ -799,15 +852,6 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
|
||||
patch_conv(padding_mode="circular")
|
||||
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
class NoCheck(ModelMixin):
    """Can be used in place of safety checker. Use responsibly and at your own risk.

    ``forward`` performs no filtering: it hands the images back unchanged
    together with a list of ``False`` flags.
    """

    def __init__(self):
        super().__init__()
        # Dummy parameter; presumably present so the module owns at least one
        # parameter for device/dtype lookups -- TODO confirm against ModelMixin.
        self.register_parameter(name="asdf", param=torch.nn.Parameter(torch.randn(3)))

    def forward(self, images=None, **kwargs):
        """Return ``images`` untouched plus one ``False`` flag per image.

        Args:
            images: Batch of images to pass through; may be ``None``.
            **kwargs: Accepted and ignored (e.g. ``clip_input``).

        Returns:
            A tuple ``(images, flags)`` where ``flags`` is a list of ``False``
            with one entry per image. When ``images`` has no length (for
            example ``None``), a single-element ``[False]`` is returned.
        """
        try:
            flag_count = len(images)
        except TypeError:
            # No usable length: keep the single-flag result.
            flag_count = 1
        return images, [False] * flag_count
|
||||
pipeline = super().from_pretrained(*args, **kwargs)
|
||||
pipeline.tiled = tiled
|
||||
return pipeline
|
@ -410,7 +410,7 @@ def layout():
|
||||
sygil_suggestions.suggestion_area(placeholder)
|
||||
|
||||
# creating the page layout using columns
|
||||
col1, col2, col3 = st.columns([1,2,1], gap="large")
|
||||
col1, col2, col3 = st.columns([2,5,2], gap="large")
|
||||
|
||||
with col1:
|
||||
width = st.slider("Width:", min_value=st.session_state['defaults'].txt2img.width.min_value, max_value=st.session_state['defaults'].txt2img.width.max_value,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -125,8 +125,11 @@ def layout():
|
||||
{'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
|
||||
{'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
|
||||
{'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
|
||||
#{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
|
||||
{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
|
||||
{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
|
||||
#{'id': 'Barfi/BaklavaJS', 'label': 'Barfi/BaklavaJS', 'icon': 'bi bi-diagram-3-fill'},
|
||||
#{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
|
||||
]},
|
||||
{'id': 'Settings', 'label': 'Settings', 'icon': 'bi bi-gear-fill'},
|
||||
#{'icon': "fa-solid fa-radar",'label':"Dropdown1", 'submenu':[
|
||||
# {'id':' subid11','icon': "fa fa-paperclip", 'label':"Sub-item 1"},{'id':'subid12','icon': "💀", 'label':"Sub-item 2"},{'id':'subid13','icon': "fa fa-database", 'label':"Sub-item 3"}]},
|
||||
@ -172,6 +175,10 @@ def layout():
|
||||
#horizontal_orientation=False,
|
||||
#override_theme={'txc_inactive': 'white','menu_background':'#111', 'stVerticalBlock': '#111','txc_active':'yellow','option_active':'blue'})
|
||||
|
||||
#
|
||||
#if menu_id == "Home":
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
|
||||
if menu_id == "Stable Diffusion":
|
||||
# set the page url and title
|
||||
#st.experimental_set_query_params(page='stable-diffusion')
|
||||
@ -227,6 +234,11 @@ def layout():
|
||||
from APIServer import layout
|
||||
layout()
|
||||
|
||||
#elif menu_id == 'Barfi/BaklavaJS':
|
||||
#set_page_title("Barfi/BaklavaJS - Stable Diffusion Playground")
|
||||
#from barfi_baklavajs import layout
|
||||
#layout()
|
||||
|
||||
elif menu_id == 'Settings':
|
||||
set_page_title("Settings - Stable Diffusion Playground")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user