Mirror of https://github.com/sd-webui/stable-diffusion-webui.git (synced 2024-12-17 10:12:04 +03:00)

Merge branch 'master' of github.com:cmdr2/hlky-webui

Commit a00eda9b32
requirements.txt
@@ -56,11 +56,11 @@ timm==0.6.7
 tqdm==4.64.0
 tensorboard==2.10.1

 # Other
 retry==0.9.2 # used by sd_utils
 python-slugify==6.1.2 # used by sd_utils
 piexif==1.1.3 # used by sd_utils
-
+pywebview==3.6.3 # used by streamlit_webview.py

 accelerate==0.12.0
 albumentations==0.4.3
@@ -45,12 +45,12 @@ def download_file(file_name, file_path, file_url):
         raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")

     try:
-        with requests.get(file_url, auth=HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token), stream=True) as r:
+        with requests.get(file_url, auth=HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None, stream=True) as r:
             r.raise_for_status()
             with open(os.path.join(file_path, file_name), 'wb') as f:
                 for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
                     f.write(chunk)
-    except HTTPError:
+    except HTTPError as e:
         if "huggingface.co" in file_url:
             if "resolve" in file_url:
                 repo_url = file_url.split("resolve")[0]
@@ -59,9 +59,12 @@ def download_file(file_name, file_path, file_url):
                     f"You need to accept the license for the model in order to be able to download it. "
                     f"Please visit {repo_url} and accept the license there, then try again to download the model.")

+        logger.error(e)

 else:
     print(file_name + ' already exists.')


 def download_model(models, model_name):
     """ Download all files from model_list[model_name] """
     for file in models[model_name]:
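The change above only sends the Hugging Face token when the URL actually points at huggingface.co, and binds the HTTPError so it can be logged. A standalone sketch of the same pattern, outside Streamlit (the token value and URL are placeholders):

import requests
from requests.auth import HTTPBasicAuth
from requests.exceptions import HTTPError

def fetch(file_url, out_path, hf_token):
    # only authenticate against huggingface.co; other hosts get an anonymous request
    auth = HTTPBasicAuth('token', hf_token) if "huggingface.co" in file_url else None
    try:
        with requests.get(file_url, auth=auth, stream=True) as r:
            r.raise_for_status()
            with open(out_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    except HTTPError as e:
        # a 401/403 from huggingface.co usually means the model license
        # has not been accepted for this account
        print(f"download failed: {e}")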
@@ -66,6 +66,9 @@ st.session_state["log"] = []

 def load_blip_model():
     logger.info("Loading BLIP Model")
+    if "log" not in st.session_state:
+        st.session_state["log"] = []
+
     st.session_state["log"].append("Loading BLIP Model")
     st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
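Streamlit re-runs the whole script on every interaction, so session keys must be created defensively before use. A minimal sketch of the guard pattern this hunk adds (the "log" key name comes from the code above):

import streamlit as st

# initialize once per session; later reruns see the existing list
if "log" not in st.session_state:
    st.session_state["log"] = []

st.session_state["log"].append("Loading BLIP Model")
st.code('\n'.join(st.session_state["log"]), language='')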
@@ -232,7 +235,7 @@ def interrogate(image, models):

     for best in bests:
         best.sort(key=lambda x: x[1], reverse=True)
         # prune to 3
         best = best[:3]

     row = [model_name]
@@ -326,7 +329,7 @@ def img2txt():

 def layout():
     #set_page_title("Image-to-Text - Stable Diffusion WebUI")
     #st.info("Under Construction. :construction_worker:")
     #
     if "clip_models" not in server_state:
         server_state["clip_models"] = {}
     if "preprocesses" not in server_state:
@@ -397,7 +400,9 @@ def layout():
         with col2:
             st.subheader("Image")

-            refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
+            image_col1, image_col2 = st.columns([10,25])
+            with image_col1:
+                refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')

             if st.session_state["uploaded_image"]:
                 #print (type(st.session_state["uploaded_image"]))
@@ -436,11 +441,12 @@ def layout():
                 #st.session_state["input_image_preview"].code('', language="")
                 st.image("images/streamlit/img2txt_placeholder.png", clamp=True)

-            #
-            # Every form must have a submit button; the extra blank spaces are a temp way to align it with the input field. Needs to be done in CSS or some other way.
-            # generate_col1.title("")
-            # generate_col1.title("")
-            generate_button = st.form_submit_button("Generate!")
+            with image_col2:
+                #
+                # Every form must have a submit button; the extra blank spaces are a temp way to align it with the input field. Needs to be done in CSS or some other way.
+                # generate_col1.title("")
+                # generate_col1.title("")
+                generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images")

         if generate_button:
             # if model, pipe, RealESRGAN or GFPGAN is in st.session_state, remove the model and pipe from session_state so that they are reloaded.
 scripts/nataili/__init__.py                    |   0 (new file)
 scripts/nataili/inference/__init__.py          |   0 (new file)
 scripts/nataili/inference/compvis/__init__.py  |   0 (new file)
 scripts/nataili/inference/compvis/img2img.py   | 551 (new file)
scripts/nataili/inference/compvis/img2img.py
@@ -0,0 +1,551 @@
import os
import re
import sys
import k_diffusion as K
import tqdm
from contextlib import contextmanager, nullcontext
import skimage
import numpy as np
import PIL
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.kdiffusion import CFGMaskedDenoiser, KDiffusionSampler
from ldm.models.diffusion.plms import PLMSSampler
from nataili.util.cache import torch_gc
from nataili.util.check_prompt_length import check_prompt_length
from nataili.util.get_next_sequence_number import get_next_sequence_number
from nataili.util.image_grid import image_grid
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
from nataili.util.save_sample import save_sample
from nataili.util.seed_to_int import seed_to_int
from slugify import slugify


class img2img:
    def __init__(self, model, device, output_dir, save_extension='jpg',
                 output_file_path=False, load_concepts=False, concepts_dir=None,
                 verify_input=True, auto_cast=True):
        self.model = model
        self.output_dir = output_dir
        self.output_file_path = output_file_path
        self.save_extension = save_extension
        self.load_concepts = load_concepts
        self.concepts_dir = concepts_dir
        self.verify_input = verify_input
        self.auto_cast = auto_cast
        self.device = device
        self.comments = []
        self.output_images = []
        self.info = ''
        self.stats = ''
        self.images = []
    def create_random_tensors(self, shape, seeds):
        xs = []
        for seed in seeds:
            torch.manual_seed(seed)

            # randn results depend on the device; GPU and CPU get different results for the same seed.
            # The way I see it, it's better to do this on the CPU, so that everyone gets the same result,
            # but the original script had it like this, so I do not dare change it for now because
            # it would break everyone's seeds.
            xs.append(torch.randn(shape, device=self.device))
        x = torch.stack(xs)
        return x
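The comment in create_random_tensors is worth demonstrating. A minimal standalone sketch (assuming a CUDA device is available) showing why the choice of device matters for seed reproducibility:

import torch

torch.manual_seed(42)
cpu_noise = torch.randn(4)                   # drawn from the CPU RNG

torch.manual_seed(42)
gpu_noise = torch.randn(4, device='cuda')    # drawn from the CUDA RNG

# The two tensors generally differ: each device has its own RNG stream,
# so the same seed does not guarantee the same noise across devices.
print(torch.allclose(cpu_noise, gpu_noise.cpu()))  # usually False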
    def process_prompt_tokens(self, prompt_tokens):
        # compvis codebase
        tokenizer = self.model.cond_stage_model.tokenizer
        text_encoder = self.model.cond_stage_model.transformer

        # diffusers codebase
        #tokenizer = pipe.tokenizer
        #text_encoder = pipe.text_encoder

        ext = ('.pt', '.bin')
        for token_name in prompt_tokens:
            embedding_path = os.path.join(self.concepts_dir, token_name)
            if os.path.exists(embedding_path):
                for files in os.listdir(embedding_path):
                    if files.endswith(ext):
                        load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
            else:
                print(f"Concept {token_name} not found in {self.concepts_dir}")
                del tokenizer, text_encoder
                return
        del tokenizer, text_encoder
    def resize_image(self, resize_mode, im, width, height):
        LANCZOS = (PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, 'Resampling') else PIL.Image.LANCZOS)
        if resize_mode == "resize":
            res = im.resize((width, height), resample=LANCZOS)
        elif resize_mode == "crop":
            ratio = width / height
            src_ratio = im.width / im.height

            src_w = width if ratio > src_ratio else im.width * height // im.height
            src_h = height if ratio <= src_ratio else im.height * width // im.width

            resized = im.resize((src_w, src_h), resample=LANCZOS)
            res = PIL.Image.new("RGBA", (width, height))
            res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
        else:
            ratio = width / height
            src_ratio = im.width / im.height

            src_w = width if ratio < src_ratio else im.width * height // im.height
            src_h = height if ratio >= src_ratio else im.height * width // im.width

            resized = im.resize((src_w, src_h), resample=LANCZOS)
            res = PIL.Image.new("RGBA", (width, height))
            res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

            if ratio < src_ratio:
                fill_height = height // 2 - src_h // 2
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
            elif ratio > src_ratio:
                fill_width = width // 2 - src_w // 2
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

        return res
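In use, "resize" stretches to the exact target size, "crop" scales to cover and center-crops the overflow, and any other mode takes the final branch: scale to fit, then pad the leftover border by stretching the image's edge rows/columns. A quick sketch (`pipe` stands in for a constructed img2img instance; the file name is a placeholder):

from PIL import Image

im = Image.open("input.png")  # placeholder input

stretched = pipe.resize_image("resize", im, 512, 512)  # distorts aspect ratio
cropped   = pipe.resize_image("crop",   im, 512, 512)  # keeps ratio, crops overflow
padded    = pipe.resize_image("fill",   im, 512, 512)  # keeps ratio, pads with stretched edges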
    # helper fft routines that keep ortho normalization and auto-shift before and after fft
    def _fft2(self, data):
        if data.ndim > 2:  # has channels
            out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
                out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
        else:  # one channel
            out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
            out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
            out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])

        return out_fft

    def _ifft2(self, data):
        if data.ndim > 2:  # has channels
            out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
                out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
        else:  # one channel
            out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
            out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
            out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])

        return out_ifft
    def _get_gaussian_window(self, width, height, std=3.14, mode=0):
        window_scale_x = float(width / min(width, height))
        window_scale_y = float(height / min(width, height))

        window = np.zeros((width, height))
        x = (np.arange(width) / width * 2. - 1.) * window_scale_x
        for y in range(height):
            fy = (y / height * 2. - 1.) * window_scale_y
            if mode == 0:
                window[:, y] = np.exp(-(x**2 + fy**2) * std)
            else:
                window[:, y] = (1 / ((x**2 + 1.) * (fy**2 + 1.))) ** (std / 3.14)  # hey wait a minute, that's not gaussian

        return window

    def _get_masked_window_rgb(self, np_mask_grey, hardness=1.):
        np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
        if hardness != 1.:
            hardened = np_mask_grey[:] ** hardness
        else:
            hardened = np_mask_grey[:]
        for c in range(3):
            np_mask_rgb[:, :, c] = hardened[:]
        return np_mask_rgb
    def get_matched_noise(self, _np_src_image, np_mask_rgb, noise_q, color_variation):
        """
        Explanation:
        Getting good results in/out-painting with stable diffusion can be challenging.
        Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data
        in the masked area to help prompt the generator. Ideally, even for in-painting we'd like to work effectively without that data as well.
        Provided here is my take on a potential solution to this problem.

        By taking a Fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src.
        Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence
        by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can
        be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color
        or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently).
        This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area.
        The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle.

        Carefully managing color and brightness with histogram matching is also essential to achieving good coherence.

        noise_q controls the exponent in the fall-off of the distribution; it can be any positive number, and lower values mean higher detail (range > 0, default 1.)
        color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01)
        This code is provided as is under the Unlicense (https://unlicense.org/)
        Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz].

        Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/)
        This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot)
        """

        global DEBUG_MODE
        global TMP_ROOT_PATH

        width = _np_src_image.shape[0]
        height = _np_src_image.shape[1]
        num_channels = _np_src_image.shape[2]

        np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
        np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
        np_src_grey = (np.sum(np_src_image, axis=2) / 3.)
        all_mask = np.ones((width, height), dtype=bool)
        img_mask = np_mask_grey > 1e-6
        ref_mask = np_mask_grey < 1e-3

        windowed_image = _np_src_image * (1. - self._get_masked_window_rgb(np_mask_grey))
        windowed_image /= np.max(windowed_image)
        windowed_image += np.average(_np_src_image) * np_mask_rgb  # / (1.-np.average(np_mask_rgb))  # rather than leave the masked area black, we get better results from fft by filling in the average unmasked color
        #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb))  # compensate for darkening across the mask transition area
        #_save_debug_img(windowed_image, "windowed_src_img")

        src_fft = self._fft2(windowed_image)  # get feature statistics from masked src img
        src_dist = np.absolute(src_fft)
        src_phase = src_fft / src_dist
        #_save_debug_img(src_dist, "windowed_src_dist")

        noise_window = self._get_gaussian_window(width, height, mode=1)  # start with simple gaussian noise
        noise_rgb = np.random.random_sample((width, height, num_channels))
        noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
        noise_rgb *= color_variation  # the colorfulness of the starting noise is blended to greyscale with a parameter
        for c in range(num_channels):
            noise_rgb[:, :, c] += (1. - color_variation) * noise_grey

        noise_fft = self._fft2(noise_rgb)
        for c in range(num_channels):
            noise_fft[:, :, c] *= noise_window
        noise_rgb = np.real(self._ifft2(noise_fft))
        shaped_noise_fft = self._fft2(noise_rgb)
        shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :])**2 * (src_dist ** noise_q) * src_phase  # perform the actual shaping

        brightness_variation = 0.  # color_variation  # todo: temporarily tying brightness variation to color variation for now
        contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.

        # scikit-image is used for histogram matching, very convenient!
        shaped_noise = np.real(self._ifft2(shaped_noise_fft))
        shaped_noise -= np.min(shaped_noise)
        shaped_noise /= np.max(shaped_noise)
        shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :]**1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
        shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
        #_save_debug_img(shaped_noise, "shaped_noise")

        matched_noise = np.zeros((width, height, num_channels))
        matched_noise = shaped_noise[:]
        #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1)
        #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb

        #_save_debug_img(matched_noise, "matched_noise")

        """
        todo:
        color_variation doesn't have to be a single number; the overall color tone of the out-painted area could be param controlled
        """

        return np.clip(matched_noise, 0., 1.)
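The docstring describes shaping seed noise to the unmasked region's spectrum; a minimal usage sketch (`pipe` is an img2img instance, file paths are placeholders, and the mask convention follows the code: 1.0 where content should be generated):

import numpy as np
from PIL import Image

src  = (np.asarray(Image.open('scene.png').convert('RGB')) / 255.0).astype(np.float64)
mask = (np.asarray(Image.open('mask.png').convert('RGB')) / 255.0).astype(np.float64)

# noise_q=1.0 keeps the default detail fall-off; color_variation=0.01 allows a small palette drift
seed_noise = pipe.get_matched_noise(src, mask, noise_q=1.0, color_variation=0.01)
# seed_noise now has roughly the same per-scale feature statistics as the unmasked region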
    def find_noise_for_image(self, model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None):
        image = np.array(init_image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        image = 2. * image - 1.
        image = image.to(device)
        x = model.get_first_stage_encoding(model.encode_first_stage(image))

        uncond = model.get_learned_conditioning([''])
        cond = model.get_learned_conditioning([prompt])

        s_in = x.new_ones([x.shape[0]])
        dnw = K.external.CompVisDenoiser(model)
        sigmas = dnw.get_sigmas(steps).flip(0)

        if verbose:
            print(sigmas)

        for i in tqdm.trange(1, len(sigmas)):
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
            cond_in = torch.cat([uncond, cond])

            c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]

            if i == 1:
                t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
            else:
                t = dnw.sigma_to_t(sigma_in)

            eps = model.apply_model(x_in * c_in, t, cond=cond_in)
            denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)

            denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale

            if i == 1:
                d = (x - denoised) / (2 * sigmas[i])
            else:
                d = (x - denoised) / sigmas[i - 1]

            dt = sigmas[i] - sigmas[i - 1]
            x = x + d * dt

        return x / sigmas[-1]
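The loop above runs the Euler sampler backwards: the sigma schedule is flipped to ascend, and each iteration integrates the latent from clean toward noisy. The generic update the code performs (the first iteration is special-cased) is, in LaTeX notation:

d_i = \frac{x - \hat{x}_\theta(x;\, \sigma_{i-1})}{\sigma_{i-1}}, \qquad x \leftarrow x + d_i \, (\sigma_i - \sigma_{i-1})

where \hat{x}_\theta is the CFG-combined denoised prediction, and the final `x / sigmas[-1]` rescales the result to unit-sigma noise.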
    def generate(self, prompt: str, init_img=None, init_mask=None, mask_mode='mask', resize_mode='resize', noise_mode='seed',
                 denoising_strength: float = 0.8, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
                 height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta: float = 0.0):
        seed = seed_to_int(seed)
        image_dict = {
            "seed": seed
        }
        # Init image is assumed to be a PIL image
        init_img = self.resize_image('resize', init_img, width, height)
        if sampler_name == 'PLMS':
            sampler = PLMSSampler(self.model)
        elif sampler_name == 'DDIM':
            sampler = DDIMSampler(self.model)
        elif sampler_name == 'k_dpm_2_a':
            sampler = KDiffusionSampler(self.model, 'dpm_2_ancestral')
        elif sampler_name == 'k_dpm_2':
            sampler = KDiffusionSampler(self.model, 'dpm_2')
        elif sampler_name == 'k_euler_a':
            sampler = KDiffusionSampler(self.model, 'euler_ancestral')
        elif sampler_name == 'k_euler':
            sampler = KDiffusionSampler(self.model, 'euler')
        elif sampler_name == 'k_heun':
            sampler = KDiffusionSampler(self.model, 'heun')
        elif sampler_name == 'k_lms':
            sampler = KDiffusionSampler(self.model, 'lms')
        else:
            raise Exception("Unknown sampler: " + sampler_name)
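The if/elif ladder maps UI sampler names to constructors. A more compact equivalent, shown here only as a sketch reusing the imports above (not the committed code), keeps the k-diffusion names in one dict:

K_SAMPLERS = {'k_dpm_2_a': 'dpm_2_ancestral', 'k_dpm_2': 'dpm_2',
              'k_euler_a': 'euler_ancestral', 'k_euler': 'euler',
              'k_heun': 'heun', 'k_lms': 'lms'}

def make_sampler(model, sampler_name):
    if sampler_name == 'PLMS':
        return PLMSSampler(model)
    if sampler_name == 'DDIM':
        return DDIMSampler(model)
    if sampler_name in K_SAMPLERS:
        return KDiffusionSampler(model, K_SAMPLERS[sampler_name])
    raise ValueError(f"Unknown sampler: {sampler_name}")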
        torch_gc()

        def process_init_mask(init_mask: PIL.Image):
            if init_mask.mode == "RGBA":
                init_mask = init_mask.convert('RGBA')
                background = PIL.Image.new('RGBA', init_mask.size, (0, 0, 0))
                init_mask = PIL.Image.alpha_composite(background, init_mask)
                init_mask = init_mask.convert('RGB')
            return init_mask

        if mask_mode == "mask":
            if init_mask:
                init_mask = process_init_mask(init_mask)
        elif mask_mode == "invert":
            if init_mask:
                init_mask = process_init_mask(init_mask)
                init_mask = PIL.ImageOps.invert(init_mask)
        elif mask_mode == "alpha":
            init_img_transparency = init_img.split()[-1].convert('L')  # .point(lambda x: 255 if x > 0 else 0, mode='1')
            init_mask = init_img_transparency
            init_mask = init_mask.convert("RGB")
            init_mask = self.resize_image(resize_mode, init_mask, width, height)
            init_mask = init_mask.convert("RGB")

        assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(denoising_strength * ddim_steps)
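`t_enc` is the number of sampler steps actually re-run on the noised input, so denoising strength directly trades fidelity for freedom. For example, with `denoising_strength=0.8` and `ddim_steps=50`, `t_enc = int(0.8 * 50) = 40`: the init latent is noised to the level of step 10 of the schedule and then denoised for the remaining 40 steps. A strength of 1.0 makes `t_enc` equal to `ddim_steps`, which triggers the `obliterate` path in `sample()` below, replacing the masked region with pure noise.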
        if init_mask is not None and (noise_mode == "matched" or noise_mode == "find_and_matched") and init_img is not None:
            noise_q = 0.99
            color_variation = 0.0
            mask_blend_factor = 1.0

            np_init = (np.asarray(init_img.convert("RGB")) / 255.0).astype(np.float64)  # annoyingly complex mask fixing
            np_mask_rgb = 1. - (np.asarray(PIL.ImageOps.invert(init_mask).convert("RGB")) / 255.0).astype(np.float64)
            np_mask_rgb -= np.min(np_mask_rgb)
            np_mask_rgb /= np.max(np_mask_rgb)
            np_mask_rgb = 1. - np_mask_rgb
            np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
            blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
            blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
            #np_mask_rgb_dilated = np_mask_rgb + blurred  # fixup mask todo: derive magic constants
            #np_mask_rgb = np_mask_rgb + blurred
            np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
            np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)

            noise_rgb = self.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
            blend_mask_rgb = np.clip(np_mask_rgb_dilated, 0., 1.) ** (mask_blend_factor)
            noised = noise_rgb[:]
            blend_mask_rgb **= (2.)
            noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb

            np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.
            ref_mask = np_mask_grey < 1e-3

            all_mask = np.ones((height, width), dtype=bool)
            noised[all_mask, :] = skimage.exposure.match_histograms(noised[all_mask, :]**1., noised[ref_mask, :], channel_axis=1)

            init_img = PIL.Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
        def init():
            image = init_img.convert('RGB')
            image = np.array(image).astype(np.float32) / 255.0
            image = image[None].transpose(0, 3, 1, 2)
            image = torch.from_numpy(image)

            mask_channel = None
            if init_mask:
                alpha = self.resize_image(resize_mode, init_mask, width // 8, height // 8)
                mask_channel = alpha.split()[-1]

            mask = None
            if mask_channel is not None:
                mask = np.array(mask_channel).astype(np.float32) / 255.0
                mask = (1 - mask)
                mask = np.tile(mask, (4, 1, 1))
                mask = mask[None].transpose(0, 1, 2, 3)
                mask = torch.from_numpy(mask).to(self.model.device)

            init_image = 2. * image - 1.
            init_image = init_image.to(self.model.device)
            init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image))  # move to latent space

            return init_latent, mask,
        def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
            t_enc_steps = t_enc
            obliterate = False
            if ddim_steps == t_enc_steps:
                t_enc_steps = t_enc_steps - 1
                obliterate = True

            if sampler_name != 'DDIM':
                x0, z_mask = init_data

                sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
                noise = x * sigmas[ddim_steps - t_enc_steps - 1]

                xi = x0 + noise

                # Obliterate masked image
                if z_mask is not None and obliterate:
                    random = torch.randn(z_mask.shape, device=xi.device)
                    xi = (z_mask * noise) + ((1 - z_mask) * xi)

                sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
                model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
                samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
                                                                                           extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
                                                                                                       'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
            else:
                x0, z_mask = init_data

                sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
                z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps] * batch_size).to(self.model.device))

                # Obliterate masked image
                if z_mask is not None and obliterate:
                    random = torch.randn(z_mask.shape, device=z_enc.device)
                    z_enc = (z_mask * random) + ((1 - z_mask) * z_enc)

                # decode it
                samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
                                              unconditional_guidance_scale=cfg_scale,
                                              unconditional_conditioning=unconditional_conditioning,
                                              z_mask=z_mask, x0=x0)
            return samples_ddim
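Both obliterate branches apply the same convex blend: with mask m, noise ε, and the noised latent x, the update is

x' = m ⊙ ε + (1 - m) ⊙ x

so masked pixels start from pure noise while unmasked pixels keep the encoded init latent. Note that in the k-diffusion branch the freshly drawn `random` tensor is assigned but the blend actually uses `noise`; only the DDIM branch blends in `random`. That asymmetry appears to be an oversight carried over from the source.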
        torch_gc()

        if self.load_concepts and self.concepts_dir is not None:
            prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
            if prompt_tokens:
                self.process_prompt_tokens(prompt_tokens)

        os.makedirs(self.output_dir, exist_ok=True)

        sample_path = os.path.join(self.output_dir, "samples")
        os.makedirs(sample_path, exist_ok=True)

        if self.verify_input:
            try:
                check_prompt_length(self.model, prompt, self.comments)
            except:
                import traceback
                print("Error verifying input:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        all_prompts = batch_size * n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]

        precision_scope = torch.autocast if self.auto_cast else nullcontext

        with torch.no_grad(), precision_scope("cuda"):
            for n in range(n_iter):
                print(f"Iteration: {n+1}/{n_iter}")
                prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
                seeds = all_seeds[n * batch_size:(n + 1) * batch_size]

                uc = self.model.get_learned_conditioning(len(prompts) * [''])

                if isinstance(prompts, tuple):
                    prompts = list(prompts)

                c = self.model.get_learned_conditioning(prompts)

                opt_C = 4
                opt_f = 8
                shape = [opt_C, height // opt_f, width // opt_f]

                x = self.create_random_tensors(shape, seeds=seeds)
                init_data = init()
                samples_ddim = sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)

                x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                for i, x_sample in enumerate(x_samples_ddim):
                    sanitized_prompt = slugify(prompts[i])
                    full_path = os.path.join(os.getcwd(), sample_path)
                    sample_path_i = sample_path
                    base_count = get_next_sequence_number(sample_path_i)
                    filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]

                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    x_sample = x_sample.astype(np.uint8)
                    image = PIL.Image.fromarray(x_sample)
                    image_dict['image'] = image
                    self.images.append(image_dict)

                    if save_individual_images:
                        path = os.path.join(sample_path, filename + '.' + self.save_extension)
                        success = save_sample(image, filename, sample_path_i, self.save_extension)
                        if success:
                            if self.output_file_path:
                                self.output_images.append(path)
                            else:
                                self.output_images.append(image)
                        else:
                            return

        self.info = f"""
                {prompt}
                Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
                """.strip()
        self.stats = f'''
                '''

        for comment in self.comments:
            self.info += "\n\n" + comment

        torch_gc()

        del sampler

        return
 scripts/nataili/inference/compvis/txt2img.py  | 201 (new file)

scripts/nataili/inference/compvis/txt2img.py
@@ -0,0 +1,201 @@
import os
import re
import sys
from contextlib import contextmanager, nullcontext

import numpy as np
import PIL
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.kdiffusion import KDiffusionSampler
from ldm.models.diffusion.plms import PLMSSampler
from nataili.util.cache import torch_gc
from nataili.util.check_prompt_length import check_prompt_length
from nataili.util.get_next_sequence_number import get_next_sequence_number
from nataili.util.image_grid import image_grid
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
from nataili.util.save_sample import save_sample
from nataili.util.seed_to_int import seed_to_int
from slugify import slugify


class txt2img:
    def __init__(self, model, device, output_dir, save_extension='jpg',
                 output_file_path=False, load_concepts=False, concepts_dir=None,
                 verify_input=True, auto_cast=True):
        self.model = model
        self.output_dir = output_dir
        self.output_file_path = output_file_path
        self.save_extension = save_extension
        self.load_concepts = load_concepts
        self.concepts_dir = concepts_dir
        self.verify_input = verify_input
        self.auto_cast = auto_cast
        self.device = device
        self.comments = []
        self.output_images = []
        self.info = ''
        self.stats = ''
        self.images = []

    def create_random_tensors(self, shape, seeds):
        xs = []
        for seed in seeds:
            torch.manual_seed(seed)

            # randn results depend on the device; GPU and CPU get different results for the same seed.
            # The way I see it, it's better to do this on the CPU, so that everyone gets the same result,
            # but the original script had it like this, so I do not dare change it for now because
            # it would break everyone's seeds.
            xs.append(torch.randn(shape, device=self.device))
        x = torch.stack(xs)
        return x

    def process_prompt_tokens(self, prompt_tokens):
        # compvis codebase
        tokenizer = self.model.cond_stage_model.tokenizer
        text_encoder = self.model.cond_stage_model.transformer

        # diffusers codebase
        #tokenizer = pipe.tokenizer
        #text_encoder = pipe.text_encoder

        ext = ('.pt', '.bin')
        for token_name in prompt_tokens:
            embedding_path = os.path.join(self.concepts_dir, token_name)
            if os.path.exists(embedding_path):
                for files in os.listdir(embedding_path):
                    if files.endswith(ext):
                        load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
            else:
                print(f"Concept {token_name} not found in {self.concepts_dir}")
                del tokenizer, text_encoder
                return
        del tokenizer, text_encoder
    def generate(self, prompt: str, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
                 height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta: float = 0.0):
        seed = seed_to_int(seed)

        image_dict = {
            "seed": seed
        }
        negprompt = ''
        if '###' in prompt:
            prompt, negprompt = prompt.split('###', 1)
            prompt = prompt.strip()
            negprompt = negprompt.strip()
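The `###` separator lets a single prompt string carry a negative prompt, which is fed in below as the unconditional conditioning. For example:

prompt = "a portrait of a cat, studio lighting ### blurry, low quality"
prompt, negprompt = prompt.split('###', 1)
print(prompt.strip())     # "a portrait of a cat, studio lighting"
print(negprompt.strip())  # "blurry, low quality"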
        if sampler_name == 'PLMS':
            sampler = PLMSSampler(self.model)
        elif sampler_name == 'DDIM':
            sampler = DDIMSampler(self.model)
        elif sampler_name == 'k_dpm_2_a':
            sampler = KDiffusionSampler(self.model, 'dpm_2_ancestral')
        elif sampler_name == 'k_dpm_2':
            sampler = KDiffusionSampler(self.model, 'dpm_2')
        elif sampler_name == 'k_euler_a':
            sampler = KDiffusionSampler(self.model, 'euler_ancestral')
        elif sampler_name == 'k_euler':
            sampler = KDiffusionSampler(self.model, 'euler')
        elif sampler_name == 'k_heun':
            sampler = KDiffusionSampler(self.model, 'heun')
        elif sampler_name == 'k_lms':
            sampler = KDiffusionSampler(self.model, 'lms')
        else:
            raise Exception("Unknown sampler: " + sampler_name)
        def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
            samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, unconditional_guidance_scale=cfg_scale,
                                             unconditional_conditioning=unconditional_conditioning, x_T=x)
            return samples_ddim

        torch_gc()

        if self.load_concepts and self.concepts_dir is not None:
            prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
            if prompt_tokens:
                self.process_prompt_tokens(prompt_tokens)

        os.makedirs(self.output_dir, exist_ok=True)

        sample_path = os.path.join(self.output_dir, "samples")
        os.makedirs(sample_path, exist_ok=True)

        if self.verify_input:
            try:
                check_prompt_length(self.model, prompt, self.comments)
            except:
                import traceback
                print("Error verifying input:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        all_prompts = batch_size * n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]

        precision_scope = torch.autocast if self.auto_cast else nullcontext

        with torch.no_grad(), precision_scope("cuda"):
            for n in range(n_iter):
                print(f"Iteration: {n+1}/{n_iter}")
                prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
                seeds = all_seeds[n * batch_size:(n + 1) * batch_size]

                uc = self.model.get_learned_conditioning(len(prompts) * [negprompt])

                if isinstance(prompts, tuple):
                    prompts = list(prompts)

                c = self.model.get_learned_conditioning(prompts)

                opt_C = 4
                opt_f = 8
                shape = [opt_C, height // opt_f, width // opt_f]

                x = self.create_random_tensors(shape, seeds=seeds)

                samples_ddim = sample(init_data=None, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)

                x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                for i, x_sample in enumerate(x_samples_ddim):
                    sanitized_prompt = slugify(prompts[i])
                    full_path = os.path.join(os.getcwd(), sample_path)
                    sample_path_i = sample_path
                    base_count = get_next_sequence_number(sample_path_i)
                    filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]

                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    x_sample = x_sample.astype(np.uint8)
                    image = PIL.Image.fromarray(x_sample)
                    image_dict['image'] = image
                    self.images.append(image_dict)

                    if save_individual_images:
                        path = os.path.join(sample_path, filename + '.' + self.save_extension)
                        success = save_sample(image, filename, sample_path_i, self.save_extension)
                        if success:
                            if self.output_file_path:
                                self.output_images.append(path)
                            else:
                                self.output_images.append(image)
                        else:
                            return

        self.info = f"""
                {prompt}
                Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
                """.strip()
        self.stats = f'''
                '''

        for comment in self.comments:
            self.info += "\n\n" + comment

        torch_gc()

        del sampler

        return
 scripts/nataili/inference/diffusers/__init__.py |   0 (new file)
 scripts/nataili/model_manager.py                | 458 (new file)

scripts/nataili/model_manager.py
@@ -0,0 +1,458 @@
import os
import json
import shutil
import zipfile
import requests
import git
import torch
import hashlib
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from transformers import logging

from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from ldm.models.blip import blip_decoder
from tqdm import tqdm
import open_clip
import clip

from nataili.util.cache import torch_gc
from nataili.util import logger

logging.set_verbosity_error()

models = json.load(open('./db.json'))
dependencies = json.load(open('./db_dep.json'))
remote_models = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
remote_dependencies = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db_dep.json"
class ModelManager():
    def __init__(self, hf_auth=None, download=True):
        if download:
            try:
                logger.init("Model Reference", status="Downloading")
                r = requests.get(remote_models)
                self.models = r.json()
                r = requests.get(remote_dependencies)
                # note: the remote dependency reference is fetched here, but the
                # local file is what actually gets loaded
                self.dependencies = json.load(open('./db_dep.json'))
                logger.init_ok("Model Reference", status="OK")
            except:
                logger.init_err("Model Reference", status="Download Error")
                self.models = json.load(open('./db.json'))
                self.dependencies = json.load(open('./db_dep.json'))
                logger.init_warn("Model Reference", status="Local")
        self.available_models = []
        self.tainted_models = []
        self.available_dependencies = []
        self.loaded_models = {}
        self.hf_auth = None
        self.set_authentication(hf_auth)

    def init(self):
        dependencies_available = []
        for dependency in self.dependencies:
            if self.check_available(self.get_dependency_files(dependency)):
                dependencies_available.append(dependency)
        self.available_dependencies = dependencies_available

        models_available = []
        for model in self.models:
            if self.check_available(self.get_model_files(model)):
                models_available.append(model)
        self.available_models = models_available

        if self.hf_auth is not None:
            if 'username' not in self.hf_auth and 'password' not in self.hf_auth:
                raise ValueError('hf_auth must contain username and password')
            else:
                if self.hf_auth['username'] == '' or self.hf_auth['password'] == '':
                    raise ValueError('hf_auth must contain username and password')
        return True

    def set_authentication(self, hf_auth=None):
        # We do not let "no authentication" override previously set auth
        if not hf_auth and self.hf_auth:
            return
        self.hf_auth = hf_auth
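Putting the pieces together, a typical workflow would look like the sketch below. The model name 'stable_diffusion' and the credentials are placeholders; real names come from db.json:

mm = ModelManager(hf_auth={'username': 'user', 'password': 'hf_...'})  # placeholder credentials
mm.init()                                                  # scan disk for already-downloaded files

if 'stable_diffusion' not in mm.get_available_models():    # hypothetical model name
    mm.download_model('stable_diffusion')                  # fetch, validate, then re-init

mm.load_model('stable_diffusion', precision='half', gpu_id=0)
model_info = mm.get_loaded_model('stable_diffusion')       # {'model': ..., 'device': ...}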
    def get_model(self, model_name):
        return self.models.get(model_name)

    def get_filtered_models(self, **kwargs):
        '''Get all model names.
        Can filter based on metadata of the model reference db
        '''
        filtered_models = self.models
        for keyword in kwargs:
            iterating_models = filtered_models.copy()
            filtered_models = {}
            for model in iterating_models:
                # logger.debug([keyword, iterating_models[model].get(keyword), kwargs[keyword]])
                if iterating_models[model].get(keyword) == kwargs[keyword]:
                    filtered_models[model] = iterating_models[model]
        return filtered_models

    def get_filtered_model_names(self, **kwargs):
        filtered_models = self.get_filtered_models(**kwargs)
        return list(filtered_models.keys())

    def get_dependency(self, dependency_name):
        return self.dependencies[dependency_name]

    def get_model_files(self, model_name):
        return self.models[model_name]['config']['files']

    def get_dependency_files(self, dependency_name):
        return self.dependencies[dependency_name]['config']['files']

    def get_model_download(self, model_name):
        return self.models[model_name]['config']['download']

    def get_dependency_download(self, dependency_name):
        return self.dependencies[dependency_name]['config']['download']

    def get_available_models(self):
        return self.available_models

    def get_available_dependencies(self):
        return self.available_dependencies

    def get_loaded_models(self):
        return self.loaded_models

    def get_loaded_models_names(self):
        return list(self.loaded_models.keys())

    def get_loaded_model(self, model_name):
        return self.loaded_models[model_name]

    def unload_model(self, model_name):
        if model_name in self.loaded_models:
            del self.loaded_models[model_name]
            return True
        return False

    def unload_all_models(self):
        for model in self.loaded_models:
            del self.loaded_models[model]
        return True

    def taint_model(self, model_name):
        '''Marks a model as not valid by removing it from available_models'''
        if model_name in self.available_models:
            self.available_models.remove(model_name)
            self.tainted_models.append(model_name)

    def taint_models(self, models):
        for model in models:
            self.taint_model(model)
    def load_model_from_config(self, model_path='', config_path='', map_location="cpu"):
        config = OmegaConf.load(config_path)
        pl_sd = torch.load(model_path, map_location=map_location)
        if "global_step" in pl_sd:
            logger.info(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        model = model.eval()
        del pl_sd, sd, m, u
        return model

    def load_ckpt(self, model_name='', precision='half', gpu_id=0):
        ckpt_path = self.get_model_files(model_name)[0]['path']
        config_path = self.get_model_files(model_name)[1]['path']
        model = self.load_model_from_config(model_path=ckpt_path, config_path=config_path)
        device = torch.device(f"cuda:{gpu_id}")
        model = (model if precision == 'full' else model.half()).to(device)
        torch_gc()
        return {'model': model, 'device': device}

    def load_realesrgan(self, model_name='', precision='half', gpu_id=0):

        RealESRGAN_models = {
            'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
            'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        }

        model_path = self.get_model_files(model_name)[0]['path']
        device = torch.device(f"cuda:{gpu_id}")
        model = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[models[model_name]['name']],
                             pre_pad=0, half=True if precision == 'half' else False, device=device)
        return {'model': model, 'device': device}

    def load_gfpgan(self, model_name='', gpu_id=0):

        model_path = self.get_model_files(model_name)[0]['path']
        device = torch.device(f"cuda:{gpu_id}")
        model = GFPGANer(model_path=model_path, upscale=1, arch='clean',
                         channel_multiplier=2, bg_upsampler=None, device=device)
        return {'model': model, 'device': device}

    def load_blip(self, model_name='', precision='half', gpu_id=0, blip_image_eval_size=512, vit='base'):
        # vit = 'base' or 'large'
        model_path = self.get_model_files(model_name)[0]['path']
        device = torch.device(f"cuda:{gpu_id}")
        model = blip_decoder(pretrained=model_path,
                             med_config="configs/blip/med_config.json",
                             image_size=blip_image_eval_size, vit=vit)
        model = model.eval()
        model = (model if precision == 'full' else model.half()).to(device)
        return {'model': model, 'device': device}

    def load_open_clip(self, model_name='', precision='half', gpu_id=0):
        pretrained = self.get_model(model_name)['pretrained_name']
        device = torch.device(f"cuda:{gpu_id}")
        model, _, preprocesses = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, cache_dir='models/clip')
        model = model.eval()
        model = (model if precision == 'full' else model.half()).to(device)
        return {'model': model, 'device': device, 'preprocesses': preprocesses}

    def load_clip(self, model_name='', precision='half', gpu_id=0):
        device = torch.device(f"cuda:{gpu_id}")
        model, preprocesses = clip.load(model_name, device=device, download_root='models/clip')
        model = model.eval()
        model = (model if precision == 'full' else model.half()).to(device)
        return {'model': model, 'device': device, 'preprocesses': preprocesses}
    def load_model(self, model_name='', precision='half', gpu_id=0):
        if model_name not in self.available_models:
            return False
        if self.models[model_name]['type'] == 'ckpt':
            self.loaded_models[model_name] = self.load_ckpt(model_name, precision, gpu_id)
            return True
        elif self.models[model_name]['type'] == 'realesrgan':
            self.loaded_models[model_name] = self.load_realesrgan(model_name, precision, gpu_id)
            return True
        elif self.models[model_name]['type'] == 'gfpgan':
            self.loaded_models[model_name] = self.load_gfpgan(model_name, gpu_id)
            return True
        elif self.models[model_name]['type'] == 'blip':
            self.loaded_models[model_name] = self.load_blip(model_name, precision, gpu_id, 512, 'base')
            return True
        elif self.models[model_name]['type'] == 'open_clip':
            self.loaded_models[model_name] = self.load_open_clip(model_name, precision, gpu_id)
            return True
        elif self.models[model_name]['type'] == 'clip':
            self.loaded_models[model_name] = self.load_clip(model_name, precision, gpu_id)
            return True
        else:
            return False

    def validate_model(self, model_name):
        files = self.get_model_files(model_name)
        all_ok = True
        for file_details in files:
            if not self.check_file_available(file_details['path']):
                return False
            if not self.validate_file(file_details):
                return False
        return True

    def validate_file(self, file_details):
        if 'md5sum' in file_details:
            file_name = file_details['path']
            logger.debug(f"Getting md5sum of {file_name}")
            with open(file_name, 'rb') as file_to_check:
                file_hash = hashlib.md5()
                while chunk := file_to_check.read(8192):
                    file_hash.update(chunk)
            if file_details['md5sum'] != file_hash.hexdigest():
                return False
        return True

    def check_file_available(self, file_path):
        return os.path.exists(file_path)

    def check_available(self, files):
        available = True
        for file in files:
            if not self.check_file_available(file['path']):
                available = False
        return available
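The accessors and validators above imply the shape of a db.json entry. A hedged sketch, written as a Python literal; only the key structure is taken from the code, and all values are invented for illustration:

entry = {
    "stable_diffusion": {                    # hypothetical model name
        "type": "ckpt",                      # dispatched on by load_model()
        "config": {
            "files": [
                {"path": "models/ldm/stable-diffusion-v1/model.ckpt",
                 "md5sum": "<md5 hex digest>"},   # optional; checked by validate_file()
                {"path": "configs/stable-diffusion/v1-inference.yaml"},
            ],
            "download": [
                {"file_name": "model.ckpt",
                 "file_path": "models/ldm/stable-diffusion-v1",
                 "file_url": "https://huggingface.co/.../model.ckpt"},
            ],
        },
    },
}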
    def download_file(self, url, file_path):
        # make directory
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        pbar_desc = file_path.split('/')[-1]
        r = requests.get(url, stream=True)
        with open(file_path, 'wb') as f:
            with tqdm(
                # all optional kwargs
                unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
                desc=pbar_desc, total=int(r.headers.get('content-length', 0))
            ) as pbar:
                for chunk in r.iter_content(chunk_size=16*1024):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

    def download_model(self, model_name):
        if model_name in self.available_models:
            logger.info(f"{model_name} is already available.")
            return True
        download = self.get_model_download(model_name)
        files = self.get_model_files(model_name)
        for i in range(len(download)):
            file_path = f"{download[i]['file_path']}/{download[i]['file_name']}" if 'file_path' in download[i] else files[i]['path']

            if 'file_url' in download[i]:
                download_url = download[i]['file_url']
                if 'hf_auth' in download[i]:
                    username = self.hf_auth['username']
                    password = self.hf_auth['password']
                    download_url = download_url.format(username=username, password=password)
            if 'file_name' in download[i]:
                download_name = download[i]['file_name']
            if 'file_path' in download[i]:
                download_path = download[i]['file_path']

            if 'manual' in download[i]:
                logger.warning(f"The model {model_name} requires manual download from {download_url}. Please place it in {download_path}/{download_name} then press ENTER to continue...")
                input('')
                continue
            # TODO: simplify
            if "file_content" in download[i]:
                file_content = download[i]['file_content']
                logger.info(f"writing {file_content} to {file_path}")
                # make directory download_path
                os.makedirs(download_path, exist_ok=True)
                # write file_content to download_path/download_name
                with open(os.path.join(download_path, download_name), 'w') as f:
                    f.write(file_content)
            elif 'symlink' in download[i]:
                logger.info(f"symlink {file_path} to {download[i]['symlink']}")
                symlink = download[i]['symlink']
                # make directory symlink
                os.makedirs(download_path, exist_ok=True)
                # make symlink from download_path/download_name to symlink
                os.symlink(symlink, os.path.join(download_path, download_name))
            elif 'git' in download[i]:
                logger.info(f"git clone {download_url} to {file_path}")
                # make directory download_path
                os.makedirs(file_path, exist_ok=True)
                git.Git(file_path).clone(download_url)
                if 'post_process' in download[i]:
                    for post_process in download[i]['post_process']:
                        if 'delete' in post_process:
                            # delete folder post_process['delete']
                            logger.info(f"delete {post_process['delete']}")
                            try:
                                shutil.rmtree(post_process['delete'])
                            except PermissionError as e:
                                logger.error(f"[!] Something went wrong while deleting the `{post_process['delete']}`. Please delete it manually.")
                                logger.error("PermissionError: ", e)
            else:
                if not self.check_file_available(file_path) or model_name in self.tainted_models:
                    logger.debug(f'Downloading {download_url} to {file_path}')
                    self.download_file(download_url, file_path)
        if not self.validate_model(model_name):
            return False
        if model_name in self.tainted_models:
            self.tainted_models.remove(model_name)
        self.init()
        return True

    def download_dependency(self, dependency_name):
        if dependency_name in self.available_dependencies:
            logger.info(f"{dependency_name} is already installed.")
            return True
        download = self.get_dependency_download(dependency_name)
        files = self.get_dependency_files(dependency_name)
        for i in range(len(download)):
|
||||||
|
if "git" in download[i]:
|
||||||
|
logger.warning("git download not implemented yet")
|
||||||
|
break
|
||||||
|
|
||||||
|
file_path = files[i]['path']
|
||||||
|
if 'file_url' in download[i]:
|
||||||
|
download_url = download[i]['file_url']
|
||||||
|
if 'file_name' in download[i]:
|
||||||
|
download_name = download[i]['file_name']
|
||||||
|
if 'file_path' in download[i]:
|
||||||
|
download_path = download[i]['file_path']
|
||||||
|
logger.debug(download_name)
|
||||||
|
if "unzip" in download[i]:
|
||||||
|
zip_path = f'temp/{download_name}.zip'
|
||||||
|
# os dirname zip_path
|
||||||
|
# mkdir temp
|
||||||
|
os.makedirs("temp", exist_ok=True)
|
||||||
|
|
||||||
|
self.download_file(download_url, zip_path)
|
||||||
|
logger.info(f"unzip {zip_path}")
|
||||||
|
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||||
|
zip_ref.extractall('temp/')
|
||||||
|
# move temp/sd-concepts-library-main/sd-concepts-library to download_path
|
||||||
|
logger.info(f"move temp/{download_name}-main/{download_name} to {download_path}")
|
||||||
|
shutil.move(f"temp/{download_name}-main/{download_name}", download_path)
|
||||||
|
logger.info(f"delete {zip_path}")
|
||||||
|
os.remove(zip_path)
|
||||||
|
logger.info(f"delete temp/{download_name}-main/")
|
||||||
|
shutil.rmtree(f"temp/{download_name}-main")
|
||||||
|
else:
|
||||||
|
if not self.check_file_available(file_path):
|
||||||
|
logger.init(f'{file_path}', status="Downloading")
|
||||||
|
self.download_file(download_url, file_path)
|
||||||
|
self.init()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def download_all_models(self):
|
||||||
|
for model in self.get_filtered_model_names(download_all = True):
|
||||||
|
if not self.check_model_available(model):
|
||||||
|
logger.init(f"{model}", status="Downloading")
|
||||||
|
self.download_model(model)
|
||||||
|
else:
|
||||||
|
logger.info(f"{model} is already downloaded.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def download_all_dependencies(self):
|
||||||
|
for dependency in self.dependencies:
|
||||||
|
if not self.check_dependency_available(dependency):
|
||||||
|
logger.init(f"{dependency}",status="Downloading")
|
||||||
|
self.download_dependency(dependency)
|
||||||
|
else:
|
||||||
|
logger.info(f"{dependency} is already installed.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def download_all(self):
|
||||||
|
self.download_all_dependencies()
|
||||||
|
self.download_all_models()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_all_available(self):
|
||||||
|
for model in self.models:
|
||||||
|
if not self.check_available(self.get_model_files(model)):
|
||||||
|
return False
|
||||||
|
for dependency in self.dependencies:
|
||||||
|
if not self.check_available(self.get_dependency_files(dependency)):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_model_available(self, model_name):
|
||||||
|
if model_name not in self.models:
|
||||||
|
return False
|
||||||
|
return self.check_available(self.get_model_files(model_name))
|
||||||
|
|
||||||
|
def check_dependency_available(self, dependency_name):
|
||||||
|
if dependency_name not in self.dependencies:
|
||||||
|
return False
|
||||||
|
return self.check_available(self.get_dependency_files(dependency_name))
|
||||||
|
|
||||||
|
def check_all_available(self):
|
||||||
|
for model in self.models:
|
||||||
|
if not self.check_model_available(model):
|
||||||
|
return False
|
||||||
|
for dependency in self.dependencies:
|
||||||
|
if not self.check_dependency_available(dependency):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
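For context, a minimal usage sketch of the download flow above (not part of the commit; the `ModelManager` import path and the model key are assumed for illustration):

from nataili.model_manager import ModelManager  # import path assumed

mm = ModelManager()
if not mm.check_model_available('GFPGAN'):      # 'GFPGAN' is an illustrative model key
    mm.download_model('GFPGAN')                 # downloads, md5-validates, then re-runs init()
mm.load_model('GFPGAN')                         # dispatches on the model's 'type' entry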
0 scripts/nataili/postprocess/__init__.py Normal file
48 scripts/nataili/postprocess/upscaler.py Normal file
@ -0,0 +1,48 @@
# Class realesrgan
# Inputs:
#  - model
#  - device
#  - output_dir
#  - output_ext
# Outputs:
#  - output_images
import os

import cv2
import PIL.Image

from nataili.util.save_sample import save_sample


class realesrgan:
    def __init__(self, model, device, output_dir, output_ext='jpg'):
        self.model = model
        self.device = device
        self.output_dir = output_dir
        self.output_ext = output_ext
        self.output_images = []

    def generate(self, input_image):
        # load image
        img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        else:
            img_mode = None
        # upscale
        output, _ = self.model.enhance(img)
        if img_mode == 'RGBA':  # RGBA images should be saved in png format
            self.output_ext = 'png'

        # OpenCV returns BGR(A); reverse the channel order for PIL
        esrgan_sample = output[:, :, ::-1]
        esrgan_image = PIL.Image.fromarray(esrgan_sample)
        # append the model name to the output image name
        filename = os.path.basename(input_image)
        filename = os.path.splitext(filename)[0]
        filename = f'{filename}_esrgan'
        filename_with_ext = f'{filename}.{self.output_ext}'
        output_image = os.path.join(self.output_dir, filename_with_ext)
        save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
        self.output_images.append(output_image)
        return
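A brief usage sketch, assuming a RealESRGAN model object (e.g. a `RealESRGANer` instance obtained through the model manager) has already been loaded; names are illustrative:

upscaler = realesrgan(model=loaded_model, device='cuda', output_dir='outputs/upscaled')
upscaler.generate('outputs/txt2img/00001-sample.png')
print(upscaler.output_images)  # e.g. ['outputs/upscaled/00001-sample_esrgan.jpg']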
0 scripts/nataili/upscalers/__init__.py Normal file
48 scripts/nataili/upscalers/realesrgan.py Normal file
@ -0,0 +1,48 @@
# Class realesrgan
# Inputs:
#  - model
#  - device
#  - output_dir
#  - output_ext
# Outputs:
#  - output_images
import os

import cv2
import PIL.Image

from nataili.util.save_sample import save_sample


class realesrgan:
    def __init__(self, model, device, output_dir, output_ext='jpg'):
        self.model = model
        self.device = device
        self.output_dir = output_dir
        self.output_ext = output_ext
        self.output_images = []

    def generate(self, input_image):
        # load image
        img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        else:
            img_mode = None
        # upscale
        output, _ = self.model.enhance(img)
        if img_mode == 'RGBA':  # RGBA images should be saved in png format
            self.output_ext = 'png'

        # OpenCV returns BGR(A); reverse the channel order for PIL
        esrgan_sample = output[:, :, ::-1]
        esrgan_image = PIL.Image.fromarray(esrgan_sample)
        # append the model name to the output image name
        filename = os.path.basename(input_image)
        filename = os.path.splitext(filename)[0]
        filename = f'{filename}_esrgan'
        filename_with_ext = f'{filename}.{self.output_ext}'
        output_image = os.path.join(self.output_dir, filename_with_ext)
        save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
        self.output_images.append(output_image)
        return
1 scripts/nataili/util/__init__.py Normal file
@ -0,0 +1 @@
from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger, test_logger
16 scripts/nataili/util/cache.py Normal file
@ -0,0 +1,16 @@
import gc

import torch


@torch.no_grad()
def torch_gc():
    # aggressively release cached GPU memory and reset allocator statistics
    for _ in range(2):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
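A typical call site, sketched: reclaim VRAM after dropping a model reference.

model = None   # drop the last reference to a large module (illustrative)
torch_gc()     # cached blocks are returned to the driver and stats are reset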
18 scripts/nataili/util/check_prompt_length.py Normal file
@ -0,0 +1,18 @@
def check_prompt_length(model, prompt, comments):
    """this function tests if prompt is too long, and if so, adds a message to comments"""

    tokenizer = model.cond_stage_model.tokenizer
    max_length = model.cond_stage_model.max_length

    info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
                                            return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
    ovf = info['overflowing_tokens'][0]
    overflowing_count = ovf.shape[0]
    if overflowing_count == 0:
        return

    vocab = {v: k for k, v in tokenizer.get_vocab().items()}
    overflowing_words = [vocab.get(int(x), "") for x in ovf]
    overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
    comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
    del tokenizer
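Sketched usage, assuming `model` is a loaded latent-diffusion model with a CLIP cond stage:

comments = []
check_prompt_length(model, 'a very long prompt ' * 40, comments)
for warning in comments:
    print(warning)  # reports how many tokens were truncated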
22 scripts/nataili/util/get_next_sequence_number.py Normal file
@ -0,0 +1,22 @@
from pathlib import Path


def get_next_sequence_number(path, prefix=''):
    """
    Determines and returns the next sequence number to use when saving an
    image in the specified directory.

    If a prefix is given, only consider files whose names start with that
    prefix, and strip the prefix from filenames before extracting their
    sequence number.

    The sequence starts at 0.
    """
    result = -1
    for p in Path(path).iterdir():
        if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
            tmp = p.name[len(prefix):]
            try:
                result = max(int(tmp.split('-')[0]), result)
            except ValueError:
                pass
    return result + 1
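Illustrative use; the `NNNNN-prompt.png` naming scheme is assumed here:

seq = get_next_sequence_number('outputs/txt2img')   # 0 for an empty directory
filename = f'{seq:05}-a red fox.png'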
21 scripts/nataili/util/image_grid.py Normal file
@ -0,0 +1,21 @@
import math

import PIL.Image


def image_grid(imgs, n_rows=None):
    if n_rows is not None:
        rows = n_rows
    else:
        rows = round(math.sqrt(len(imgs)))

    cols = math.ceil(len(imgs) / rows)

    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h), color='black')

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid
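A quick sketch of the helper, tiling four samples into a 2x2 contact sheet:

import PIL.Image
imgs = [PIL.Image.new('RGB', (64, 64), c) for c in ('red', 'green', 'blue', 'white')]
grid = image_grid(imgs)   # rows defaults to round(sqrt(4)) == 2
grid.save('grid.png')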
40 scripts/nataili/util/load_learned_embed_in_clip.py Normal file
@ -0,0 +1,40 @@
import os

import torch


def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    # separate token and the embeds
    if learned_embeds_path.endswith('.pt'):
        # old format: token is '*', so replace it with the file name when converting
        trained_token = os.path.basename(learned_embeds_path)
        params_dict = {
            trained_token: torch.tensor(list(loaded_learned_embeds['string_to_param'].items())[0][1])
        }
        learned_embeds_path = os.path.splitext(learned_embeds_path)[0] + '.bin'
        torch.save(params_dict, learned_embeds_path)
        loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # cast to the dtype of the text encoder (Tensor.to returns a new tensor)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # add the token to the tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)

    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))

    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
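Sketched usage with the Hugging Face transformers CLIP classes; the model id and embedding path are illustrative:

from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
token = load_learned_embed_in_clip('models/custom/cat-toy/learned_embeds.bin',
                                   text_encoder, tokenizer)
print(f'The concept can now be referenced as {token} in prompts.')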
102 scripts/nataili/util/logger.py Normal file
@ -0,0 +1,102 @@
import sys
from functools import partialmethod

from loguru import logger

STDOUT_LEVELS = ["GENERATION", "PROMPT"]
INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
MESSAGE_LEVELS = ["MESSAGE"]
# By default we're at INFO level (no=20) or higher
verbosity = 20
quiet = 0


def set_logger_verbosity(count):
    global verbosity
    # The count comes reversed: count = 0 means minimum verbosity
    # while count = 5 means maximum verbosity.
    # So the more count we have, the lower we drop the verbosity threshold.
    verbosity = 20 - (count * 10)


def quiesce_logger(count):
    global quiet
    # The bigger the count, the more silent we want our logger
    quiet = count * 10


def is_stdout_log(record):
    if record["level"].name not in STDOUT_LEVELS:
        return False
    if record["level"].no < verbosity + quiet:
        return False
    return True


def is_init_log(record):
    if record["level"].name not in INIT_LEVELS:
        return False
    if record["level"].no < verbosity + quiet:
        return False
    return True


def is_msg_log(record):
    if record["level"].name not in MESSAGE_LEVELS:
        return False
    if record["level"].no < verbosity + quiet:
        return False
    return True


def is_stderr_log(record):
    if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
        return False
    if record["level"].no < verbosity + quiet:
        return False
    return True


def test_logger():
    logger.generation("This is a generation message\nIt is typically multiline\nThree Lines".encode("unicode_escape").decode("utf-8"))
    logger.prompt("This is a prompt message")
    logger.debug("Debug Message")
    logger.info("Info Message")
    logger.warning("Warning Message")
    logger.error("Error Message")
    logger.critical("Critical Message")
    logger.init("This is an init message", status="Starting")
    logger.init_ok("This is an init message", status="OK")
    logger.init_warn("This is an init message", status="Warning")
    logger.init_err("This is an init message", status="Error")
    logger.message("This is a user message")
    sys.exit()


logfmt = "<level>{level: <10}</level> | <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <green>{name}</green>:<green>{function}</green>:<green>{line}</green> - <level>{message}</level>"
genfmt = "<level>{level: <10}</level> @ <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{message}</level>"
initfmt = "<magenta>INIT </magenta> | <level>{extra[status]: <11}</level> | <magenta>{message}</magenta>"
msgfmt = "<level>{level: <10}</level> | <level>{message}</level>"

try:
    logger.level("GENERATION", no=24, color="<cyan>")
    logger.level("PROMPT", no=23, color="<yellow>")
    logger.level("INIT", no=31, color="<white>")
    logger.level("INIT_OK", no=31, color="<green>")
    logger.level("INIT_WARN", no=31, color="<yellow>")
    logger.level("INIT_ERR", no=31, color="<red>")
    # Messages contain important information without which this application might not be able to be used
    # As such, they have the highest priority
    logger.level("MESSAGE", no=61, color="<green>")
except TypeError:
    pass

logger.__class__.generation = partialmethod(logger.__class__.log, "GENERATION")
logger.__class__.prompt = partialmethod(logger.__class__.log, "PROMPT")
logger.__class__.init = partialmethod(logger.__class__.log, "INIT")
logger.__class__.init_ok = partialmethod(logger.__class__.log, "INIT_OK")
logger.__class__.init_warn = partialmethod(logger.__class__.log, "INIT_WARN")
logger.__class__.init_err = partialmethod(logger.__class__.log, "INIT_ERR")
logger.__class__.message = partialmethod(logger.__class__.log, "MESSAGE")

config = {
    "handlers": [
        {"sink": sys.stderr, "format": logfmt, "colorize": True, "filter": is_stderr_log},
        {"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize": True, "filter": is_stdout_log},
        {"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize": True, "filter": is_init_log},
        {"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize": True, "filter": is_msg_log},
    ],
}
logger.configure(**config)
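A sketch of how the custom levels read at a call site (the re-exports come from scripts/nataili/util/__init__.py above):

from nataili.util import logger, set_logger_verbosity

set_logger_verbosity(3)   # the count is 'reversed': a higher count lowers the threshold
logger.init('Model Manager', status='Starting')
logger.init_ok('Model Manager', status='OK')
logger.prompt('a photograph of an astronaut riding a horse')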
20 scripts/nataili/util/save_sample.py Normal file
@ -0,0 +1,20 @@
import os


def save_sample(image, filename, sample_path, extension='png', jpg_quality=95, webp_quality=95, webp_lossless=True, png_compression=9):
    path = os.path.join(sample_path, filename + '.' + extension)
    if os.path.exists(path):
        return False
    if not os.path.exists(sample_path):
        os.makedirs(sample_path)
    if extension == 'png':
        image.save(path, format='PNG', compress_level=png_compression)
    elif extension == 'jpg':
        image.save(path, quality=jpg_quality, optimize=True)
    elif extension == 'webp':
        image.save(path, quality=webp_quality, lossless=webp_lossless)
    else:
        return False
    # report whether the file actually landed on disk
    return os.path.exists(path)
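Sketched call; note the early `return False` means an existing file is never overwritten:

import PIL.Image
img = PIL.Image.new('RGB', (512, 512))
saved = save_sample(img, '00042-sample', 'outputs/txt2img', extension='webp')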
22 scripts/nataili/util/seed_to_int.py Normal file
@ -0,0 +1,22 @@
import random


def seed_to_int(s):
    if type(s) is int:
        return s
    if s is None or s == '':
        return random.randint(0, 2**32 - 1)

    if type(s) is list:
        # convert each entry, filling blanks with random seeds
        seed_list = []
        for seed in s:
            if seed is None or seed == '':
                seed_list.append(random.randint(0, 2**32 - 1))
            else:
                seed_list.append(seed_to_int(seed))
        return seed_list

    # derive a deterministic 32-bit seed from an arbitrary string
    n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
    while n >= 2**32:
        n = n >> 32
    return n
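Illustrative behaviour:

seed_to_int(1234)        # 1234: ints pass through unchanged
seed_to_int('')          # fresh random seed in [0, 2**32)
seed_to_int('red fox')   # deterministic: the same string always maps to the same seed
seed_to_int(['', 42])    # [<random seed>, 42]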
15 streamlit_webview.py Normal file
@ -0,0 +1,15 @@
import os

import webview
from streamlit import config as _config
from streamlit.web import bootstrap

# open a native window pointing at the Streamlit server;
# webview.start() blocks until the window is closed
webview.create_window('Sygil', 'http://localhost:8501', width=1000, height=800, min_size=(500, 500))
webview.start()

dirname = os.path.dirname(__file__)
filename = os.path.join(dirname, 'scripts/webui_streamlit.py')

_config.set_option("server.headless", True)
args = []

# streamlit.cli.main_run(filename, args)
bootstrap.run(filename, '', args, flag_options={})
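Presumably launched as `python streamlit_webview.py` from the repository root; since `webview.start()` blocks, a Streamlit server must already be listening on port 8501 for the window to show the UI, and `bootstrap.run` is only reached after the window closes.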
15 webui.cmd
@ -1,17 +1,17 @@
 @echo off
 :: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
-
+::
 :: Copyright 2022 Sygil-Dev team.
 :: This program is free software: you can redistribute it and/or modify
 :: it under the terms of the GNU Affero General Public License as published by
 :: the Free Software Foundation, either version 3 of the License, or
 :: (at your option) any later version.
-
+::
 :: This program is distributed in the hope that it will be useful,
 :: but WITHOUT ANY WARRANTY; without even the implied warranty of
 :: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 :: GNU Affero General Public License for more details.
-
+::
 :: You should have received a copy of the GNU Affero General Public License
 :: along with this program. If not, see <http://www.gnu.org/licenses/>.
 :: Run all commands using this script's directory as the working directory
@ -102,12 +102,11 @@ call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
 
 :PROMPT
 set SETUPTOOLS_USE_DISTUTILS=stdlib
-IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" (
-set "PYTHONPATH=%~dp0"
-python scripts\relauncher.py %*
+IF EXIST "models\ldm\stable-diffusion-v1\Stable Diffusion v1.5.ckpt" (
+python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
 ) ELSE (
-echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'.
-pause
+echo Your model file does not exist! Once the WebUI launches please visit the Model Manager page and download the models by using the Download button for each model.
+python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
 )
 
 ::cmd /k
@ -1,17 +1,17 @@
 @echo off
 :: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
-::
+
 :: Copyright 2022 Sygil-Dev team.
 :: This program is free software: you can redistribute it and/or modify
 :: it under the terms of the GNU Affero General Public License as published by
 :: the Free Software Foundation, either version 3 of the License, or
 :: (at your option) any later version.
-::
+
 :: This program is distributed in the hope that it will be useful,
 :: but WITHOUT ANY WARRANTY; without even the implied warranty of
 :: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 :: GNU Affero General Public License for more details.
-::
+
 :: You should have received a copy of the GNU Affero General Public License
 :: along with this program. If not, see <http://www.gnu.org/licenses/>.
 :: Run all commands using this script's directory as the working directory
@ -99,7 +99,8 @@ call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
 :PROMPT
 set SETUPTOOLS_USE_DISTUTILS=stdlib
 IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" (
-python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
+set "PYTHONPATH=%~dp0"
+python scripts\relauncher.py %*
 ) ELSE (
 echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'.
 pause