Mirror of https://github.com/Sygil-Dev/sygil-webui.git
Added basic image editor for the img2img tab. (#1613)
Added an image editor for img2img. For now it's pretty basic and has some bugs, but it should do the job better than having to create a mask in external software and import it into the UI every time.
Commit e9a4e46f66
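The editor is built on the `streamlit-drawable-canvas` component that this commit wires into the img2img tab. Below is a stripped-down, hedged sketch of the same idea outside the webui; the file name, sizes, and variable names are illustrative only, not the exact ones used in the commit.

```python
# Minimal sketch: draw an img2img mask with streamlit-drawable-canvas.
# Run with: streamlit run sketch.py   (file name is a placeholder)
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

uploaded = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
if uploaded:
    image = Image.open(uploaded).convert("RGB").resize((512, 512))

    canvas = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",      # semi-transparent fill, as in the commit
        stroke_width=st.slider("Stroke width: ", 1, 100, 50),
        stroke_color="#EEEEEE",
        background_image=image,                   # paint the mask directly over the init image
        drawing_mode="freedraw",
        width=512,
        height=512,
        key="mask_canvas",
    )

    if canvas.image_data is not None:
        # The painted strokes come back as an RGBA array; the commit converts this
        # with Image.fromarray() and passes it to img2img as the mask.
        mask = Image.fromarray(canvas.image_data)
        st.image(mask, caption="Mask preview")
```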
@@ -230,7 +230,8 @@ img2img:
         step: 0.01
     # 0: Keep masked area
     # 1: Regenerate only masked area
-    mask_mode: 0
+    mask_mode: 1
+    noise_mode: "Matched Noise"
     mask_restore: False
     # 0: Just resize
     # 1: Crop and resize
@@ -33,6 +33,7 @@ from ldm.models.diffusion.plms import PLMSSampler

 # streamlit components
 from custom_components import sygil_suggestions
+from streamlit_drawable_canvas import st_canvas

 # Temp imports

@@ -381,7 +382,7 @@ def layout():


     # creating the page layout using columns
-    col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="medium")
+    col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium")

     with col1_img2img_layout:
         # If we have custom models available on the "models/custom"
@@ -426,7 +427,7 @@ def layout():
         mask_expander = st.empty()
         with mask_expander.expander("Mask"):
             mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
-            mask_mode = st.selectbox("Mask Mode", mask_mode_list,
+            mask_mode = st.selectbox("Mask Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode,
                                      help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\
                                      \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
                                      )
@@ -434,15 +435,32 @@ def layout():


             noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
-            noise_mode = st.selectbox(
-                "Noise Mode", noise_mode_list,
-                help=""
-            )
-            noise_mode = noise_mode_list.index(noise_mode)
+            noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="")
+            #noise_mode = noise_mode_list.index(noise_mode)
             find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value,
                                                min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
                                                step=st.session_state['defaults'].img2img.find_noise_steps.step)

+            # Specify canvas parameters in application
+            drawing_mode = st.selectbox(
+                "Drawing tool:",
+                (
+                    "freedraw",
+                    "transform",
+                    #"line",
+                    "rect",
+                    "circle",
+                    #"polygon",
+                ),
+            )
+
+            stroke_width = st.slider("Stroke width: ", 1, 100, 50)
+            stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE")
+            bg_color = st.color_picker("Background color hex: ", "#7B6E6E")
+
+            display_toolbar = st.checkbox("Display toolbar", True)
+            #realtime_update = st.checkbox("Update in realtime", True)
+
         with st.expander("Batch Options"):
             st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value,
                                                               help="How many iterations or batches of images to generate in total.")
@@ -583,55 +601,63 @@ def layout():
             editor_image = st.empty()
             st.session_state["editor_image"] = editor_image

+            st.form_submit_button("Refresh")
+
+            #if "canvas" not in st.session_state:
+            st.session_state["canvas"] = st.empty()
+
             masked_image_holder = st.empty()
             image_holder = st.empty()

-            st.form_submit_button("Refresh")
-
             uploaded_images = st.file_uploader(
                 "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
                 help="Upload an image which will be used for the image to image generation.",
             )
             if uploaded_images:
-                image = Image.open(uploaded_images).convert('RGBA')
+                image = Image.open(uploaded_images).convert('RGB')
                 new_img = image.resize((width, height))
-                image_holder.image(new_img)
-
-            mask_holder = st.empty()
-
-            uploaded_masks = st.file_uploader(
-                "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
-                help="Upload an mask image which will be used for masking the image to image generation.",
+                #image_holder.image(new_img)
+
+            #mask_holder = st.empty()
+
+            #uploaded_masks = st.file_uploader(
+                #"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
+                #help="Upload an mask image which will be used for masking the image to image generation.",
+            #)
+
+            #
+            # Create a canvas component
+            with st.session_state["canvas"]:
+                st.session_state["uploaded_masks"] = st_canvas(
+                    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
+                    stroke_width=stroke_width,
+                    stroke_color=stroke_color,
+                    background_color=bg_color,
+                    background_image=image if uploaded_images else None,
+                    update_streamlit=True,
+                    width=width,
+                    height=height,
+                    drawing_mode=drawing_mode,
+                    initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None,
+                    display_toolbar= display_toolbar,
+                    key="full_app",
                 )
-            if uploaded_masks:
-                mask_expander.expander("Mask", expanded=True)
-                mask = Image.open(uploaded_masks)
-                if mask.mode == "RGBA":
-                    mask = mask.convert('RGBA')
-                    background = Image.new('RGBA', mask.size, (0, 0, 0))
-                    mask = Image.alpha_composite(background, mask)
-                mask = mask.resize((width, height))
-                mask_holder.image(mask)
-
-            if uploaded_images and uploaded_masks:
-                if mask_mode != 2:
-                    final_img = new_img.copy()
-                    alpha_layer = mask.convert('L')
-                    strength = st.session_state["denoising_strength"]
-                    if mask_mode == 0:
-                        alpha_layer = ImageOps.invert(alpha_layer)
-                        alpha_layer = alpha_layer.point(lambda a: a * strength)
-                        alpha_layer = ImageOps.invert(alpha_layer)
-                    elif mask_mode == 1:
-                        alpha_layer = alpha_layer.point(lambda a: a * strength)
-                        alpha_layer = ImageOps.invert(alpha_layer)
-
-                    final_img.putalpha(alpha_layer)
-
-                with masked_image_holder.container():
-                    st.text("Masked Image Preview")
-                    st.image(final_img)
+            #try:
+                ##print (type(st.session_state["uploaded_masks"]))
+                #if st.session_state["uploaded_masks"] != None:
+                    #mask_expander.expander("Mask", expanded=True)
+                    #mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
+                    #st.image(mask)
+                    #if mask.mode == "RGBA":
+                        #mask = mask.convert('RGBA')
+                        #background = Image.new('RGBA', mask.size, (0, 0, 0))
+                        #mask = Image.alpha_composite(background, mask)
+                        #mask = mask.resize((width, height))
+            #except AttributeError:
+                #pass

             with col3_img2img_layout:
                 result_tab = st.tabs(["Result"])
@@ -645,7 +671,6 @@ def layout():
                 st.session_state["progress_bar_text"] = st.empty()
                 st.session_state["progress_bar"] = st.empty()

-
                 message = st.empty()

                 #if uploaded_images:
@@ -666,14 +691,17 @@ def layout():
                                    CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])

                 if uploaded_images:
-                    image = Image.open(uploaded_images).convert('RGBA')
-                    new_img = image.resize((width, height))
-                    #img_array = np.array(image) # if you want to pass it to OpenCV
+                    #image = Image.fromarray(image).convert('RGBA')
+                    #new_img = image.resize((width, height))
+                    ###img_array = np.array(image) # if you want to pass it to OpenCV
+                    #image_holder.image(new_img)
                     new_mask = None
-                    if uploaded_masks:
-                        mask = Image.open(uploaded_masks).convert('RGBA')
+
+                    if st.session_state["uploaded_masks"]:
+                        mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
                         new_mask = mask.resize((width, height))

+                        #masked_image_holder.image(new_mask)
                     try:
                         output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
                                                                    mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
@@ -14,15 +14,13 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 # base webui import and utils.
-import collections.abc
 #from webui_streamlit import st
 import gfpgan
 import hydralit as st

-
 # streamlit imports
 from streamlit import StopException, StreamlitAPIException
-from streamlit.runtime.scriptrunner import script_run_context
+#from streamlit.runtime.scriptrunner import script_run_context

 #streamlit components section
 from streamlit_server_state import server_state, server_state_lock
@@ -35,7 +33,7 @@ import streamlit_nested_layout
 import warnings
 import json

-import base64
+import base64, cv2
 import os, sys, re, random, datetime, time, math, glob, toml
 import gc
 from PIL import Image, ImageFont, ImageDraw, ImageFilter
@@ -68,15 +66,31 @@ import piexif.helper
 from tqdm import trange
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import ismap
-from abc import ABC, abstractmethod
+#from abc import ABC, abstractmethod
 from typing import Dict, Union
 from io import BytesIO
 from packaging import version
 from uuid import uuid4
+from pathlib import Path
+from huggingface_hub import hf_hub_download

 #import librosa
-from logger import logger, set_logger_verbosity, quiesce_logger
+#from logger import logger, set_logger_verbosity, quiesce_logger
 #from loguru import logger

+from nataili.inference.compvis.img2img import img2img
+from nataili.model_manager import ModelManager
+from nataili.inference.compvis.txt2img import txt2img
+from nataili.util.cache import torch_gc
+from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
+
+try:
+    from realesrgan import RealESRGANer
+    from basicsr.archs.rrdbnet_arch import RRDBNet
+except ImportError as e:
+    logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
+        "pip install realesrgan")
+
 # Temp imports
 #from basicsr.utils.registry import ARCH_REGISTRY

@@ -84,14 +98,6 @@ from logger import logger, set_logger_verbosity, quiesce_logger
 # end of imports
 #---------------------------------------------------------------------------------------------------------------

-# we make a log file where we store the logs
-logger.add("logs/log_{time:MM-DD-YYYY!UTC}.log", rotation="8 MB", compression="zip", level='INFO')   # Once the file is too old, it's rotated
-logger.add(sys.stderr, diagnose=True)
-logger.add(sys.stdout)
-logger.enable("")
-
-#
-
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
     from transformers import logging
@@ -112,6 +118,8 @@ mimetypes.add_type('application/javascript', '.js')
 opt_C = 4
 opt_f = 8

+# The model manager loads and unloads the SD models and has features to download them or find their location
+#model_manager = ModelManager()

 def load_configs():
     if not "defaults" in st.session_state:
@@ -269,6 +277,33 @@ def make_grid(n_items=5, n_cols=5):

     return cols


+def merge(file1, file2, out, weight):
+    alpha = (weight)/100
+    if not(file1.endswith(".ckpt")):
+        file1 += ".ckpt"
+    if not(file2.endswith(".ckpt")):
+        file2 += ".ckpt"
+    if not(out.endswith(".ckpt")):
+        out += ".ckpt"
+    #Load Models
+    model_0 = torch.load(file1)
+    model_1 = torch.load(file2)
+    theta_0 = model_0['state_dict']
+    theta_1 = model_1['state_dict']
+
+    for key in theta_0.keys():
+        if 'model' in key and key in theta_1:
+            theta_0[key] = (alpha) * theta_0[key] + (1-alpha) * theta_1[key]
+
+    logger.info("RUNNING...\n(STAGE 2)")
+
+    for key in theta_1.keys():
+        if 'model' in key and key not in theta_0:
+            theta_0[key] = theta_1[key]
+    torch.save(model_0, out)
+
+
 def human_readable_size(size, decimal_places=3):
     """Return a human readable size from bytes."""
     for unit in ['B','KB','MB','GB','TB']:
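The new `merge()` helper above blends two Stable Diffusion checkpoints by linearly interpolating their `state_dict` tensors, with `weight` given as a percentage. A hedged usage sketch follows; the checkpoint names are placeholders, and the import assumes this module is importable as `sd_utils`, as the `from sd_utils import RealESRGANModel` change later in this diff suggests.

```python
# Hedged sketch: blend two local checkpoints 70/30 and write the result.
# "modelA", "modelB" and "merged" are placeholder names, not files from this repo.
from sd_utils import merge  # assumed module name

# ".ckpt" is appended automatically when missing; both full checkpoints are
# loaded into memory with torch.load, so expect several GB of RAM use.
merge("modelA", "modelB", "merged", weight=70)
```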
@@ -282,6 +317,8 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
                 CustomModel_available=False, custom_model="Stable Diffusion v1.5"):
     """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """

+    #model_manager.init()
+
     logger.info("Loading models.")

     if "progress_bar_text" in st.session_state:
@@ -1350,6 +1387,77 @@ def load_RealESRGAN(model_name: str):

     return server_state['RealESRGAN']

+#
+class RealESRGANModel(nn.Module):
+    def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
+        super().__init__()
+        try:
+            from basicsr.archs.rrdbnet_arch import RRDBNet
+            from realesrgan import RealESRGANer
+        except ImportError as e:
+            logger.error(
+                "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
+                "pip install realesrgan"
+            )
+
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        self.upsampler = RealESRGANer(
+            scale=4, model_path=model_path, model=model, tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, half=not fp32
+        )
+
+    def forward(self, image, outscale=4, convert_to_pil=True):
+        """Upsample an image array or path.
+        Args:
+            image (Union[np.ndarray, str]): Either a np array or an image path. np array is assumed to be in RGB format,
+                and we convert it to BGR.
+            outscale (int, optional): Amount to upscale the image. Defaults to 4.
+            convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True.
+        Returns:
+            Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.
+        """
+        if isinstance(image, (str, Path)):
+            img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
+        else:
+            img = image
+            img = (img * 255).round().astype("uint8")
+            img = img[:, :, ::-1]
+
+        image, _ = self.upsampler.enhance(img, outscale=outscale)
+
+        if convert_to_pil:
+            image = Image.fromarray(image[:, :, ::-1])
+
+        return image
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path="nateraw/real-esrgan"):
+        """Initialize a pretrained Real-ESRGAN upsampler.
+        Args:
+            model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'.
+        Returns:
+            PipelineRealESRGAN: An instance of `PipelineRealESRGAN` instantiated from pretrained model.
+        """
+        # reuploaded form official ones mentioned here:
+        # https://github.com/xinntao/Real-ESRGAN
+        if Path(model_name_or_path).exists():
+            file = model_name_or_path
+        else:
+            file = hf_hub_download(model_name_or_path, "RealESRGAN_x4plus.pth")
+        return cls(file)
+
+    def upsample_imagefolder(self, in_dir, out_dir, suffix="out", outfile_ext=".png"):
+        in_dir, out_dir = Path(in_dir), Path(out_dir)
+        if not in_dir.exists():
+            raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")
+
+        out_dir.mkdir(exist_ok=True, parents=True)
+
+        image_paths = [x for x in in_dir.glob("*") if x.suffix.lower() in [".png", ".jpg", ".jpeg"]]
+        for image in image_paths:
+            im = self(str(image))
+            out_filepath = out_dir / (image.stem + suffix + outfile_ext)
+            im.save(out_filepath)
+
 #
 @retry(tries=5)
 def load_LDSR(model_name="model", config="project", checking=False):
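A minimal sketch of how the new `RealESRGANModel` wrapper could be driven once this file is imported; the input path is a placeholder and the `sd_utils` module name is assumed from the pipeline hunk further down.

```python
# Hedged sketch: upscale one image 4x with the wrapper added above.
# "photo.png" is a placeholder path, not part of the commit.
from sd_utils import RealESRGANModel  # assumed module name

upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")  # fetches RealESRGAN_x4plus.pth from the Hub
upscaled = upsampler("photo.png", outscale=4)  # returns a PIL.Image by default
upscaled.save("photo_x4.png")
```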
@@ -1744,6 +1852,9 @@ def seed_to_int(s):
     if s is None or s == '':
         return random.randint(0, 2**32 - 1)

+    if ',' in s:
+        s = s.split(',')
+
     if type(s) is list:
         seed_list = []
         for seed in s:
@@ -1955,6 +2066,7 @@ def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, widt

     filename_i = os.path.join(sample_path_i, filename)

+    if "defaults" in st.session_state:
     if st.session_state['defaults'].general.save_metadata or write_info_files:
         # toggles differ for txt2img vs. img2img:
         offset = 0 if init_img is None else 2
@@ -2563,7 +2675,7 @@ def process_images(
                     #output_images.append(image)
                     #if simple_templating:
                         #grid_captions.append( captions[i] )
-
+        if "defaults" in st.session_state:
         if st.session_state['defaults'].general.optimized:
             mem = torch.cuda.memory_allocated()/1e6
             server_state["modelFS"].to("cpu")
@@ -10,7 +10,6 @@ import time
 import json

 import torch
-from diffusers import ModelMixin
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
@@ -22,59 +21,39 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from torch import nn

-from .upsampling import RealESRGANModel
+from sd_utils import RealESRGANModel


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def get_spec_norm(wav, sr, n_mels=512, hop_length=704):
-    """Obtain maximum value for each time-frame in Mel Spectrogram,
-    and normalize between 0 and 1
-
-    Borrowed from lucid sonic dreams repo. In there, they programatically determine hop length
-    but I really didn't understand what was going on so I removed it and hard coded the output.
-    """
-
-    # Generate Mel Spectrogram
-    spec_raw = librosa.feature.melspectrogram(y=wav, sr=sr, n_mels=n_mels, hop_length=hop_length)
-
-    # Obtain maximum value per time-frame
-    spec_max = np.amax(spec_raw, axis=0)
-
-    # Normalize all values between 0 and 1
-    spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
-
-    return spec_norm
-
-
-def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=(1.0, 5.0)):
-    """Get the array that will be used to determine how much to interpolate between images.
-
-    Normally, this is just a linspace between 0 and 1 for the number of frames to generate. In this case,
-    we want to use the amplitude of the audio to determine how much to interpolate between images.
-
-    So, here we:
-    1. Load the audio file
-    2. Split the audio into harmonic and percussive components
-    3. Get the normalized amplitude of the percussive component, resized to the number of frames
-    4. Get the cumulative sum of the amplitude array
-    5. Normalize the cumulative sum between 0 and 1
-    6. Return the array
-
-    I honestly have no clue what I'm doing here. Suggestions welcome.
-    """
-    y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
-    wav_harmonic, wav_percussive = librosa.effects.hpss(y, margin=margin)
-
-    # Apparently n_mels is supposed to be input shape but I don't think it matters here?
-    frame_duration = int(sr / fps)
-    wav_norm = get_spec_norm(wav_percussive, sr, n_mels=512, hop_length=frame_duration)
-    amplitude_arr = np.resize(wav_norm, int(duration * fps))
-    T = np.cumsum(amplitude_arr)
-    T /= T[-1]
-    T[0] = 0.0
-    return T
+def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
+    y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
+
+    # librosa.stft hardcoded defaults...
+    # n_fft defaults to 2048
+    # hop length is win_length // 4
+    # win_length defaults to n_fft
+    D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)
+
+    # Extract percussive elements
+    D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
+    y_percussive = librosa.istft(D_percussive, length=len(y))
+
+    # Get normalized melspectrogram
+    spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
+    spec_max = np.amax(spec_raw, axis=0)
+    spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
+
+    # Resize cumsum of spec norm to our desired number of interpolation frames
+    x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
+    y_norm = np.cumsum(spec_norm)
+    y_norm /= y_norm[-1]
+    x_resize = np.linspace(0, y_norm.shape[-1], int(duration*fps))
+
+    T = np.interp(x_resize, x_norm, y_norm)
+
+    # Apply smoothing
+    return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth


 def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
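The rewritten `get_timesteps_arr()` above now builds the interpolation schedule from the cumulative energy of the percussive component (`librosa.decompose.hpss` on the STFT) and exposes `margin`/`smooth` knobs. A hedged sketch of how such a schedule could be consumed alongside the file's `slerp()` helper; the audio path, frame rate, and latent tensors are placeholders, not part of the commit.

```python
# Hedged sketch: turn 4 seconds of audio into per-frame interpolation weights.
# "beat.wav" is a placeholder path; fps=12 is an arbitrary choice.
T = get_timesteps_arr("beat.wav", offset=0.0, duration=4.0, fps=12, margin=1.0, smooth=0.2)

# T increases (roughly) from 0 to 1; frames advance faster where the
# percussive energy is higher. latents_a/latents_b stand in for two noise
# tensors such as those produced by the pipeline's init_noise() helper.
frames_latents = [slerp(float(t), latents_a, latents_b) for t in T]
```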
@@ -130,7 +109,6 @@ def make_video_pyav(
             frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
             frames = frame if frames is None else torch.cat([frames, frame])
     else:
-
         frames = frames_or_frame_dir

     # TCHW -> THWC
@@ -208,6 +186,16 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

+        if safety_checker is None:
+            logger.warn(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -251,6 +239,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         width: int = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
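With `negative_prompt` and `num_images_per_prompt` added to the pipeline's `__call__` signature, a call could look like the following sketch; the prompts are placeholders, and it assumes a `StableDiffusionWalkPipeline` instance named `pipe` already exists.

```python
# Hedged sketch: `pipe` stands in for an already-loaded StableDiffusionWalkPipeline.
output = pipe(
    prompt="a watercolor painting of a lighthouse",  # placeholder prompt
    negative_prompt="blurry, low quality",           # new argument added in this hunk
    num_images_per_prompt=2,                         # new argument added in this hunk
    num_inference_steps=50,
    guidance_scale=7.5,
)
images = output["images"]  # the same key the diff switches to further down ("sample" -> "images")
```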
@@ -259,12 +249,13 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: Optional[int] = 1,
         text_embeddings: Optional[torch.FloatTensor] = None,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
         Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+            prompt (`str` or `List[str]`, *optional*, defaults to `None`):
+                The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -278,6 +269,11 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -300,8 +296,10 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-            text_embeddings(`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings.
+            text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
+                Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
+                `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
+                the supplied `prompt`.
         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
@@ -340,7 +338,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):

             if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                 removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-                logger.warning(
+                print(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
                     f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                 )
@@ -349,21 +347,51 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         else:
             batch_size = text_embeddings.shape[0]

+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            # HACK - Not setting text_input_ids here when walking, so hard coding to max length of tokenizer
-            # TODO - Determine if this is OK to do
-            # max_length = text_input_ids.shape[-1]
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""]
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
             max_length = self.tokenizer.model_max_length
             uncond_input = self.tokenizer(
-                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
@@ -374,19 +402,20 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_dtype = text_embeddings.dtype
         if latents is None:
-            latents = torch.randn(
-                latents_shape,
-                generator=generator,
-                device=latents_device,
-                dtype=text_embeddings.dtype,
-            )
+            if self.device.type == "mps":
+                # randn does not exist on mps
+                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+                    self.device
+                )
+            else:
+                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(latents_device)
+            latents = latents.to(self.device)

         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -431,12 +460,19 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         image = self.vae.decode(latents).sample

         image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
-
-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-        )
+
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+                self.device
+            )
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+            )
+        else:
+            has_nsfw_concept = None

         if output_type == "pil":
             image = self.numpy_to_pil(image)
@@ -449,16 +485,9 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
     def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
         embeds_a = self.embed_text(prompt_a)
         embeds_b = self.embed_text(prompt_b)
-        latents_a = torch.randn(
-            noise_shape,
-            device=self.device,
-            generator=torch.Generator(device=self.device).manual_seed(seed_a),
-        )
-        latents_b = torch.randn(
-            noise_shape,
-            device=self.device,
-            generator=torch.Generator(device=self.device).manual_seed(seed_b),
-        )
+        latents_a = self.init_noise(seed_a, noise_shape)
+        latents_b = self.init_noise(seed_b, noise_shape)

         batch_idx = 0
         embeds_batch, noise_batch = None, None
@@ -477,7 +506,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 torch.cuda.empty_cache()
                 embeds_batch, noise_batch = None, None

-    def generate_interpolation_clip(
+    def make_clip_frames(
         self,
         prompt_a: str,
         prompt_b: str,
@@ -530,7 +559,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 eta=eta,
                 num_inference_steps=num_inference_steps,
                 output_type="pil" if not upsample else "numpy",
-            )["sample"]
+            )["images"]

             for image in outputs:
                 frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
@@ -557,6 +586,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         resume: Optional[bool] = False,
         audio_filepath: str = None,
         audio_start_sec: Optional[Union[int, float]] = None,
+        margin: Optional[float] = 1.0,
+        smooth: Optional[float] = 0.0,
     ):
         """Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
         video to interpolate to the intensity of the audio.
@@ -603,13 +634,17 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 Optional path to an audio file to influence the interpolation rate.
             audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
                 Global start time of the provided audio_filepath.
+            margin (Optional[float], *optional*, defaults to 1.0):
+                Margin from librosa hpss to use for audio interpolation.
+            smooth (Optional[float], *optional*, defaults to 0.0):
+                Smoothness of the audio interpolation. 1.0 means linear interpolation.

         This function will create sub directories for each prompt and seed pair.

         For example, if you provide the following prompts and seeds:

         ```
-        prompts = ['a', 'b', 'c']
+        prompts = ['a dog', 'a cat', 'a bird']
         seeds = [1, 2, 3]
         num_interpolation_steps = 5
         output_dir = 'output_dir'
@@ -722,7 +757,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                 audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
                 audio_duration = num_step / fps

-                self.generate_interpolation_clip(
+                self.make_clip_frames(
                     prompt_a,
                     prompt_b,
                     seed_a,
@@ -742,7 +777,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
                             offset=audio_offset,
                             duration=audio_duration,
                             fps=fps,
-                            margin=(1.0, 5.0),
+                            margin=margin,
+                            smooth=smooth,
                         )
                         if audio_filepath
                         else None,
@@ -783,6 +819,23 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
         return embed

+    def init_noise(self, seed, noise_shape):
+        """Helper to initialize noise"""
+        # randn does not exist on mps, so we create noise on CPU here and move it to the device after initialization
+        if self.device.type == "mps":
+            noise = torch.randn(
+                noise_shape,
+                device='cpu',
+                generator=torch.Generator(device='cpu').manual_seed(seed),
+            ).to(self.device)
+        else:
+            noise = torch.randn(
+                noise_shape,
+                device=self.device,
+                generator=torch.Generator(device=self.device).manual_seed(seed),
+            )
+        return noise
+
     @classmethod
     def from_pretrained(cls, *args, tiled=False, **kwargs):
         """Same as diffusers `from_pretrained` but with tiled option, which makes images tilable"""
@@ -799,15 +852,6 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):

         patch_conv(padding_mode="circular")

-        return super().from_pretrained(*args, **kwargs)
+        pipeline = super().from_pretrained(*args, **kwargs)
+        pipeline.tiled = tiled
+        return pipeline
-
-class NoCheck(ModelMixin):
-    """Can be used in place of safety checker. Use responsibly and at your own risk."""
-
-    def __init__(self):
-        super().__init__()
-        self.register_parameter(name="asdf", param=torch.nn.Parameter(torch.randn(3)))
-
-    def forward(self, images=None, **kwargs):
-        return images, [False]
@@ -410,7 +410,7 @@ def layout():
         sygil_suggestions.suggestion_area(placeholder)

     # creating the page layout using columns
-    col1, col2, col3 = st.columns([1,2,1], gap="large")
+    col1, col2, col3 = st.columns([2,5,2], gap="large")

     with col1:
         width = st.slider("Width:", min_value=st.session_state['defaults'].txt2img.width.min_value, max_value=st.session_state['defaults'].txt2img.width.max_value,
One file's diff was suppressed because it is too large and is not shown here.
@@ -125,8 +125,11 @@ def layout():
                 {'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
                 {'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
                 {'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
-                #{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
+                {'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
                 {'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
+                #{'id': 'Barfi/BaklavaJS', 'label': 'Barfi/BaklavaJS', 'icon': 'bi bi-diagram-3-fill'},
+                #{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
+                ]},
                 {'id': 'Settings', 'label': 'Settings', 'icon': 'bi bi-gear-fill'},
                 #{'icon': "fa-solid fa-radar",'label':"Dropdown1", 'submenu':[
                 #    {'id':' subid11','icon': "fa fa-paperclip", 'label':"Sub-item 1"},{'id':'subid12','icon': "💀", 'label':"Sub-item 2"},{'id':'subid13','icon': "fa fa-database", 'label':"Sub-item 3"}]},
@@ -172,6 +175,10 @@ def layout():
             #horizontal_orientation=False,
             #override_theme={'txc_inactive': 'white','menu_background':'#111', 'stVerticalBlock': '#111','txc_active':'yellow','option_active':'blue'})

+    #
+    #if menu_id == "Home":
+        #st.info("Under Construction. :construction_worker:")
+
     if menu_id == "Stable Diffusion":
         # set the page url and title
         #st.experimental_set_query_params(page='stable-diffusion')
@@ -227,6 +234,11 @@ def layout():
        from APIServer import layout
        layout()

+    #elif menu_id == 'Barfi/BaklavaJS':
+        #set_page_title("Barfi/BaklavaJS - Stable Diffusion Playground")
+        #from barfi_baklavajs import layout
+        #layout()
+
    elif menu_id == 'Settings':
        set_page_title("Settings - Stable Diffusion Playground")
