diff --git a/requirements.txt b/requirements.txt
index 6f69981..3247e3e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -56,11 +56,11 @@ timm==0.6.7
tqdm==4.64.0
tensorboard==2.10.1
-
# Other
retry==0.9.2 # used by sd_utils
python-slugify==6.1.2 # used by sd_utils
piexif==1.1.3 # used by sd_utils
+pywebview==3.6.3 # used by streamlit_webview.py
accelerate==0.12.0
albumentations==0.4.3
diff --git a/scripts/ModelManager.py b/scripts/ModelManager.py
index 07abc07..3f86ca1 100644
--- a/scripts/ModelManager.py
+++ b/scripts/ModelManager.py
@@ -45,12 +45,12 @@ def download_file(file_name, file_path, file_url):
raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
try:
- with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token), stream=True) as r:
+ with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None, stream=True) as r:
r.raise_for_status()
with open(os.path.join(file_path, file_name), 'wb') as f:
for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
f.write(chunk)
- except HTTPError:
+ except HTTPError as e:
if "huggingface.co" in file_url:
if "resolve"in file_url:
repo_url = file_url.split("resolve")[0]
@@ -59,9 +59,12 @@ def download_file(file_name, file_path, file_url):
f"You need to accept the license for the model in order to be able to download it. "
f"Please visit {repo_url} and accept the lincense there, then try again to download the model.")
+ logger.error(e)
+
else:
print(file_name + ' already exists.')
+
def download_model(models, model_name):
""" Download all files from model_list[model_name] """
for file in models[model_name]:
diff --git a/scripts/img2txt.py b/scripts/img2txt.py
index f9a7f44..72f260a 100644
--- a/scripts/img2txt.py
+++ b/scripts/img2txt.py
@@ -66,6 +66,9 @@ st.session_state["log"] = []
def load_blip_model():
logger.info("Loading BLIP Model")
+ if "log" not in st.session_state:
+ st.session_state["log"] = []
+
st.session_state["log"].append("Loading BLIP Model")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
@@ -232,7 +235,7 @@ def interrogate(image, models):
for best in bests:
best.sort(key=lambda x: x[1], reverse=True)
- # prune to 3
+ # prune to 3
best = best[:3]
row = [model_name]
@@ -326,7 +329,7 @@ def img2txt():
def layout():
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
#st.info("Under Construction. :construction_worker:")
- #
+ #
if "clip_models" not in server_state:
server_state["clip_models"] = {}
if "preprocesses" not in server_state:
@@ -397,7 +400,9 @@ def layout():
with col2:
st.subheader("Image")
- refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
+ image_col1, image_col2 = st.columns([10,25])
+ with image_col1:
+ refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
if st.session_state["uploaded_image"]:
#print (type(st.session_state["uploaded_image"]))
@@ -436,11 +441,12 @@ def layout():
#st.session_state["input_image_preview"].code('', language="")
st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
- #
- # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
- # generate_col1.title("")
- # generate_col1.title("")
- generate_button = st.form_submit_button("Generate!")
+ with image_col2:
+ #
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ # generate_col1.title("")
+ # generate_col1.title("")
+ generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images")
if generate_button:
# if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded.
diff --git a/scripts/nataili/__init__.py b/scripts/nataili/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/nataili/inference/__init__.py b/scripts/nataili/inference/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/nataili/inference/compvis/__init__.py b/scripts/nataili/inference/compvis/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/nataili/inference/compvis/img2img.py b/scripts/nataili/inference/compvis/img2img.py
new file mode 100644
index 0000000..d55474b
--- /dev/null
+++ b/scripts/nataili/inference/compvis/img2img.py
@@ -0,0 +1,551 @@
+import os
+import re
+import sys
+import k_diffusion as K
+import tqdm
+from contextlib import contextmanager, nullcontext
+import skimage
+import numpy as np
+import PIL
+import torch
+from einops import rearrange
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.kdiffusion import CFGMaskedDenoiser, KDiffusionSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from nataili.util.cache import torch_gc
+from nataili.util.check_prompt_length import check_prompt_length
+from nataili.util.get_next_sequence_number import get_next_sequence_number
+from nataili.util.image_grid import image_grid
+from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
+from nataili.util.save_sample import save_sample
+from nataili.util.seed_to_int import seed_to_int
+from slugify import slugify
+import PIL
+
+
+class img2img:
+ def __init__(self, model, device, output_dir, save_extension='jpg',
+ output_file_path=False, load_concepts=False, concepts_dir=None,
+ verify_input=True, auto_cast=True):
+ self.model = model
+ self.output_dir = output_dir
+ self.output_file_path = output_file_path
+ self.save_extension = save_extension
+ self.load_concepts = load_concepts
+ self.concepts_dir = concepts_dir
+ self.verify_input = verify_input
+ self.auto_cast = auto_cast
+ self.device = device
+ self.comments = []
+ self.output_images = []
+ self.info = ''
+ self.stats = ''
+ self.images = []
+
+ def create_random_tensors(self, shape, seeds):
+ xs = []
+ for seed in seeds:
+ torch.manual_seed(seed)
+
+ # randn results depend on device; gpu and cpu get different results for same seed;
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
+ # but the original script had it like this so i do not dare change it for now because
+ # it will break everyone's seeds.
+ xs.append(torch.randn(shape, device=self.device))
+ x = torch.stack(xs)
+ return x
+
+ def process_prompt_tokens(self, prompt_tokens):
+ # compviz codebase
+ tokenizer = self.model.cond_stage_model.tokenizer
+ text_encoder = self.model.cond_stage_model.transformer
+
+ # diffusers codebase
+ #tokenizer = pipe.tokenizer
+ #text_encoder = pipe.text_encoder
+
+ ext = ('.pt', '.bin')
+ for token_name in prompt_tokens:
+ embedding_path = os.path.join(self.concepts_dir, token_name)
+ if os.path.exists(embedding_path):
+ for files in os.listdir(embedding_path):
+ if files.endswith(ext):
+ load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
+ else:
+ print(f"Concept {token_name} not found in {self.concepts_dir}")
+ del tokenizer, text_encoder
+ return
+ del tokenizer, text_encoder
+
+ def resize_image(self, resize_mode, im, width, height):
+ LANCZOS = (PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, 'Resampling') else PIL.Image.LANCZOS)
+ if resize_mode == "resize":
+ res = im.resize((width, height), resample=LANCZOS)
+ elif resize_mode == "crop":
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio > src_ratio else im.width * height // im.height
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = PIL.Image.new("RGBA", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+ else:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio < src_ratio else im.width * height // im.height
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = PIL.Image.new("RGBA", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+
+ return res
+
+ #
+ # helper fft routines that keep ortho normalization and auto-shift before and after fft
+ def _fft2(self, data):
+ if data.ndim > 2: # has channels
+ out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:,:,c]
+ out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho")
+ out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c])
+ else: # one channel
+ out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho")
+ out_fft[:,:] = np.fft.ifftshift(out_fft[:,:])
+
+ return out_fft
+
+ def _ifft2(self, data):
+ if data.ndim > 2: # has channels
+ out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:,:,c]
+ out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho")
+ out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c])
+ else: # one channel
+ out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho")
+ out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:])
+
+ return out_ifft
+
+ def _get_gaussian_window(self, width, height, std=3.14, mode=0):
+
+ window_scale_x = float(width / min(width, height))
+ window_scale_y = float(height / min(width, height))
+
+ window = np.zeros((width, height))
+ x = (np.arange(width) / width * 2. - 1.) * window_scale_x
+ for y in range(height):
+ fy = (y / height * 2. - 1.) * window_scale_y
+ if mode == 0:
+ window[:, y] = np.exp(-(x**2+fy**2) * std)
+ else:
+ window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian
+
+ return window
+
+ def _get_masked_window_rgb(self, np_mask_grey, hardness=1.):
+ np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
+ if hardness != 1.:
+ hardened = np_mask_grey[:] ** hardness
+ else:
+ hardened = np_mask_grey[:]
+ for c in range(3):
+ np_mask_rgb[:,:,c] = hardened[:]
+ return np_mask_rgb
+
+ def get_matched_noise(self, _np_src_image, np_mask_rgb, noise_q, color_variation):
+ """
+ Explanation:
+ Getting good results in/out-painting with stable diffusion can be challenging.
+ Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data
+ in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well.
+ Provided here is my take on a potential solution to this problem.
+
+ By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src.
+ Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence
+ by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can
+ be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color
+ or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently)
+ This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area.
+ The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle.
+
+ Carefully managing color and brightness with histogram matching is also essential to achieving good coherence.
+
+ noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.)
+ color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01)
+ This code is provided as is under the Unlicense (https://unlicense.org/)
+ Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz].
+
+ Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/)
+ This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot)
+
+ """
+
+ global DEBUG_MODE
+ global TMP_ROOT_PATH
+
+ width = _np_src_image.shape[0]
+ height = _np_src_image.shape[1]
+ num_channels = _np_src_image.shape[2]
+
+ np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
+ np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.)
+ np_src_grey = (np.sum(np_src_image, axis=2)/3.)
+ all_mask = np.ones((width, height), dtype=bool)
+ img_mask = np_mask_grey > 1e-6
+ ref_mask = np_mask_grey < 1e-3
+
+ windowed_image = _np_src_image * (1.-self._get_masked_window_rgb(np_mask_grey))
+ windowed_image /= np.max(windowed_image)
+ windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
+ #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area
+ #_save_debug_img(windowed_image, "windowed_src_img")
+
+ src_fft = self._fft2(windowed_image) # get feature statistics from masked src img
+ src_dist = np.absolute(src_fft)
+ src_phase = src_fft / src_dist
+ #_save_debug_img(src_dist, "windowed_src_dist")
+
+ noise_window = self._get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
+ noise_rgb = np.random.random_sample((width, height, num_channels))
+ noise_grey = (np.sum(noise_rgb, axis=2)/3.)
+ noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
+ for c in range(num_channels):
+ noise_rgb[:,:,c] += (1. - color_variation) * noise_grey
+
+ noise_fft = self._fft2(noise_rgb)
+ for c in range(num_channels):
+ noise_fft[:,:,c] *= noise_window
+ noise_rgb = np.real(self._ifft2(noise_fft))
+ shaped_noise_fft = self._fft2(noise_rgb)
+ shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
+
+ brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now
+ contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
+
+ # scikit-image is used for histogram matching, very convenient!
+ shaped_noise = np.real(self._ifft2(shaped_noise_fft))
+ shaped_noise -= np.min(shaped_noise)
+ shaped_noise /= np.max(shaped_noise)
+ shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1)
+ shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
+ #_save_debug_img(shaped_noise, "shaped_noise")
+
+ matched_noise = np.zeros((width, height, num_channels))
+ matched_noise = shaped_noise[:]
+ #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1)
+ #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb
+
+ #_save_debug_img(matched_noise, "matched_noise")
+
+ """
+ todo:
+ color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled
+ """
+
+ return np.clip(matched_noise, 0., 1.)
+
+ def find_noise_for_image(self, model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None):
+ image = np.array(init_image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ image = 2. * image - 1.
+ image = image.to(device)
+ x = model.get_first_stage_encoding(model.encode_first_stage(image))
+
+ uncond = model.get_learned_conditioning([''])
+ cond = model.get_learned_conditioning([prompt])
+
+ s_in = x.new_ones([x.shape[0]])
+ dnw = K.external.CompVisDenoiser(model)
+ sigmas = dnw.get_sigmas(steps).flip(0)
+
+ if verbose:
+ print(sigmas)
+
+ for i in tqdm.trange(1, len(sigmas)):
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
+ cond_in = torch.cat([uncond, cond])
+
+ c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
+
+ if i == 1:
+ t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
+ else:
+ t = dnw.sigma_to_t(sigma_in)
+
+ eps = model.apply_model(x_in * c_in, t, cond=cond_in)
+ denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
+
+ denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale
+
+ if i == 1:
+ d = (x - denoised) / (2 * sigmas[i])
+ else:
+ d = (x - denoised) / sigmas[i - 1]
+
+ dt = sigmas[i] - sigmas[i - 1]
+ x = x + d * dt
+
+ return x / sigmas[-1]
+
+ def generate(self, prompt: str, init_img=None, init_mask=None, mask_mode='mask', resize_mode='resize', noise_mode='seed',
+ denoising_strength:float=0.8, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
+ height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
+ seed = seed_to_int(seed)
+ image_dict = {
+ "seed": seed
+ }
+ # Init image is assumed to be a PIL image
+ init_img = self.resize_image('resize', init_img, width, height)
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(self.model)
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(self.model)
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(self.model,'dpm_2')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(self.model,'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(self.model,'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(self.model,'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(self.model,'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
+ torch_gc()
+ def process_init_mask(init_mask: PIL.Image):
+ if init_mask.mode == "RGBA":
+ init_mask = init_mask.convert('RGBA')
+ background = PIL.Image.new('RGBA', init_mask.size, (0, 0, 0))
+ init_mask = PIL.Image.alpha_composite(background, init_mask)
+ init_mask = init_mask.convert('RGB')
+ return init_mask
+
+ if mask_mode == "mask":
+ if init_mask:
+ init_mask = process_init_mask(init_mask)
+ elif mask_mode == "invert":
+ if init_mask:
+ init_mask = process_init_mask(init_mask)
+ init_mask = PIL.ImageOps.invert(init_mask)
+ elif mask_mode == "alpha":
+ init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
+ init_mask = init_img_transparency
+ init_mask = init_mask.convert("RGB")
+ init_mask = self.resize_image(resize_mode, init_mask, width, height)
+ init_mask = init_mask.convert("RGB")
+
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ t_enc = int(denoising_strength * ddim_steps)
+
+ if init_mask is not None and (noise_mode == "matched" or noise_mode == "find_and_matched") and init_img is not None:
+ noise_q = 0.99
+ color_variation = 0.0
+ mask_blend_factor = 1.0
+
+ np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
+ np_mask_rgb = 1. - (np.asarray(PIL.ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
+ np_mask_rgb -= np.min(np_mask_rgb)
+ np_mask_rgb /= np.max(np_mask_rgb)
+ np_mask_rgb = 1. - np_mask_rgb
+ np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
+ blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
+ #np_mask_rgb = np_mask_rgb + blurred
+ np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
+ np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
+
+ noise_rgb = self.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
+ blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
+ noised = noise_rgb[:]
+ blend_mask_rgb **= (2.)
+ noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
+
+ np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
+ ref_mask = np_mask_grey < 1e-3
+
+ all_mask = np.ones((height, width), dtype=bool)
+ noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
+
+ init_img = PIL.Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
+
+ def init():
+ image = init_img.convert('RGB')
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+
+ mask_channel = None
+ if init_mask:
+ alpha = self.resize_image(resize_mode, init_mask, width // 8, height // 8)
+ mask_channel = alpha.split()[-1]
+
+ mask = None
+ if mask_channel is not None:
+ mask = np.array(mask_channel).astype(np.float32) / 255.0
+ mask = (1 - mask)
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3)
+ mask = torch.from_numpy(mask).to(self.model.device)
+
+ init_image = 2. * image - 1.
+ init_image = init_image.to(self.model.device)
+ init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
+
+ return init_latent, mask,
+
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ t_enc_steps = t_enc
+ obliterate = False
+ if ddim_steps == t_enc_steps:
+ t_enc_steps = t_enc_steps - 1
+ obliterate = True
+
+ if sampler_name != 'DDIM':
+ x0, z_mask = init_data
+
+ sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
+ noise = x * sigmas[ddim_steps - t_enc_steps - 1]
+
+ xi = x0 + noise
+
+ # Obliterate masked image
+ if z_mask is not None and obliterate:
+ random = torch.randn(z_mask.shape, device=xi.device)
+ xi = (z_mask * noise) + ((1-z_mask) * xi)
+
+ sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
+ model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
+ samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
+ extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
+ 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
+ else:
+
+ x0, z_mask = init_data
+
+ sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
+ z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(self.model.device))
+
+ # Obliterate masked image
+ if z_mask is not None and obliterate:
+ random = torch.randn(z_mask.shape, device=z_enc.device)
+ z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
+
+ # decode it
+ samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ z_mask=z_mask, x0=x0)
+ return samples_ddim
+
+ torch_gc()
+
+ if self.load_concepts and self.concepts_dir is not None:
+ prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+ if prompt_tokens:
+ self.process_prompt_tokens(prompt_tokens)
+
+ os.makedirs(self.output_dir, exist_ok=True)
+
+ sample_path = os.path.join(self.output_dir, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+
+ if self.verify_input:
+ try:
+ check_prompt_length(self.model, prompt, self.comments)
+ except:
+ import traceback
+ print("Error verifying input:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ all_prompts = batch_size * n_iter * [prompt]
+ all_seeds = [seed + x for x in range(len(all_prompts))]
+
+ precision_scope = torch.autocast if self.auto_cast else nullcontext
+
+ with torch.no_grad(), precision_scope("cuda"):
+ for n in range(n_iter):
+ print(f"Iteration: {n+1}/{n_iter}")
+ prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+
+ uc = self.model.get_learned_conditioning(len(prompts) * [''])
+
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ c = self.model.get_learned_conditioning(prompts)
+
+ opt_C = 4
+ opt_f = 8
+ shape = [opt_C, height // opt_f, width // opt_f]
+
+ x = self.create_random_tensors(shape, seeds=seeds)
+ init_data = init()
+ samples_ddim = sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+
+ x_samples_ddim = self.model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i, x_sample in enumerate(x_samples_ddim):
+ sanitized_prompt = slugify(prompts[i])
+ full_path = os.path.join(os.getcwd(), sample_path)
+ sample_path_i = sample_path
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
+
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+ image = PIL.Image.fromarray(x_sample)
+ image_dict['image'] = image
+ self.images.append(image_dict)
+
+ if save_individual_images:
+ path = os.path.join(sample_path, filename + '.' + self.save_extension)
+ success = save_sample(image, filename, sample_path_i, self.save_extension)
+ if success:
+ if self.output_file_path:
+ self.output_images.append(path)
+ else:
+ self.output_images.append(image)
+ else:
+ return
+
+ self.info = f"""
+ {prompt}
+ Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
+ """.strip()
+ self.stats = f'''
+ '''
+
+ for comment in self.comments:
+ self.info += "\n\n" + comment
+
+ torch_gc()
+
+ del sampler
+
+ return
diff --git a/scripts/nataili/inference/compvis/txt2img.py b/scripts/nataili/inference/compvis/txt2img.py
new file mode 100644
index 0000000..cb11c67
--- /dev/null
+++ b/scripts/nataili/inference/compvis/txt2img.py
@@ -0,0 +1,201 @@
+import os
+import re
+import sys
+from contextlib import contextmanager, nullcontext
+
+import numpy as np
+import PIL
+import torch
+from einops import rearrange
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.kdiffusion import KDiffusionSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from nataili.util.cache import torch_gc
+from nataili.util.check_prompt_length import check_prompt_length
+from nataili.util.get_next_sequence_number import get_next_sequence_number
+from nataili.util.image_grid import image_grid
+from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
+from nataili.util.save_sample import save_sample
+from nataili.util.seed_to_int import seed_to_int
+from slugify import slugify
+
+
+class txt2img:
+ def __init__(self, model, device, output_dir, save_extension='jpg',
+ output_file_path=False, load_concepts=False, concepts_dir=None,
+ verify_input=True, auto_cast=True):
+ self.model = model
+ self.output_dir = output_dir
+ self.output_file_path = output_file_path
+ self.save_extension = save_extension
+ self.load_concepts = load_concepts
+ self.concepts_dir = concepts_dir
+ self.verify_input = verify_input
+ self.auto_cast = auto_cast
+ self.device = device
+ self.comments = []
+ self.output_images = []
+ self.info = ''
+ self.stats = ''
+ self.images = []
+
+ def create_random_tensors(self, shape, seeds):
+ xs = []
+ for seed in seeds:
+ torch.manual_seed(seed)
+
+ # randn results depend on device; gpu and cpu get different results for same seed;
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
+ # but the original script had it like this so i do not dare change it for now because
+ # it will break everyone's seeds.
+ xs.append(torch.randn(shape, device=self.device))
+ x = torch.stack(xs)
+ return x
+
+ def process_prompt_tokens(self, prompt_tokens):
+ # compviz codebase
+ tokenizer = self.model.cond_stage_model.tokenizer
+ text_encoder = self.model.cond_stage_model.transformer
+
+ # diffusers codebase
+ #tokenizer = pipe.tokenizer
+ #text_encoder = pipe.text_encoder
+
+ ext = ('.pt', '.bin')
+ for token_name in prompt_tokens:
+ embedding_path = os.path.join(self.concepts_dir, token_name)
+ if os.path.exists(embedding_path):
+ for files in os.listdir(embedding_path):
+ if files.endswith(ext):
+ load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
+ else:
+ print(f"Concept {token_name} not found in {self.concepts_dir}")
+ del tokenizer, text_encoder
+ return
+ del tokenizer, text_encoder
+
+ def generate(self, prompt: str, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
+ height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
+ seed = seed_to_int(seed)
+
+ image_dict = {
+ "seed": seed
+ }
+ negprompt = ''
+ if '###' in prompt:
+ prompt, negprompt = prompt.split('###', 1)
+ prompt = prompt.strip()
+ negprompt = negprompt.strip()
+
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(self.model)
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(self.model)
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(self.model,'dpm_2')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(self.model,'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(self.model,'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(self.model,'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(self.model,'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning, x_T=x)
+ return samples_ddim
+
+ torch_gc()
+
+ if self.load_concepts and self.concepts_dir is not None:
+ prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+ if prompt_tokens:
+ self.process_prompt_tokens(prompt_tokens)
+
+ os.makedirs(self.output_dir, exist_ok=True)
+
+ sample_path = os.path.join(self.output_dir, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+
+ if self.verify_input:
+ try:
+ check_prompt_length(self.model, prompt, self.comments)
+ except:
+ import traceback
+ print("Error verifying input:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ all_prompts = batch_size * n_iter * [prompt]
+ all_seeds = [seed + x for x in range(len(all_prompts))]
+
+ precision_scope = torch.autocast if self.auto_cast else nullcontext
+
+ with torch.no_grad(), precision_scope("cuda"):
+ for n in range(n_iter):
+ print(f"Iteration: {n+1}/{n_iter}")
+ prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+
+ uc = self.model.get_learned_conditioning(len(prompts) * [negprompt])
+
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ c = self.model.get_learned_conditioning(prompts)
+
+ opt_C = 4
+ opt_f = 8
+ shape = [opt_C, height // opt_f, width // opt_f]
+
+ x = self.create_random_tensors(shape, seeds=seeds)
+
+ samples_ddim = sample(init_data=None, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+
+ x_samples_ddim = self.model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i, x_sample in enumerate(x_samples_ddim):
+ sanitized_prompt = slugify(prompts[i])
+ full_path = os.path.join(os.getcwd(), sample_path)
+ sample_path_i = sample_path
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
+
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+ image = PIL.Image.fromarray(x_sample)
+ image_dict['image'] = image
+ self.images.append(image_dict)
+
+ if save_individual_images:
+ path = os.path.join(sample_path, filename + '.' + self.save_extension)
+ success = save_sample(image, filename, sample_path_i, self.save_extension)
+ if success:
+ if self.output_file_path:
+ self.output_images.append(path)
+ else:
+ self.output_images.append(image)
+ else:
+ return
+
+ self.info = f"""
+ {prompt}
+ Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
+ """.strip()
+ self.stats = f'''
+ '''
+
+ for comment in self.comments:
+ self.info += "\n\n" + comment
+
+ torch_gc()
+
+ del sampler
+
+ return
diff --git a/scripts/nataili/inference/diffusers/__init__.py b/scripts/nataili/inference/diffusers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/nataili/model_manager.py b/scripts/nataili/model_manager.py
new file mode 100644
index 0000000..ee3bae1
--- /dev/null
+++ b/scripts/nataili/model_manager.py
@@ -0,0 +1,458 @@
+import os
+import json
+import shutil
+import zipfile
+import requests
+import git
+import torch
+import hashlib
+from ldm.util import instantiate_from_config
+from omegaconf import OmegaConf
+from transformers import logging
+
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from gfpgan import GFPGANer
+from realesrgan import RealESRGANer
+from ldm.models.blip import blip_decoder
+from tqdm import tqdm
+import open_clip
+import clip
+
+from nataili.util.cache import torch_gc
+from nataili.util import logger
+
+logging.set_verbosity_error()
+
+models = json.load(open('./db.json'))
+dependencies = json.load(open('./db_dep.json'))
+remote_models = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
+remote_dependencies = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db_dep.json"
+
+class ModelManager():
+ def __init__(self, hf_auth=None, download=True):
+ if download:
+ try:
+ logger.init("Model Reference", status="Downloading")
+ r = requests.get(remote_models)
+ self.models = r.json()
+ r = requests.get(remote_dependencies)
+ self.dependencies = json.load(open('./db_dep.json'))
+ logger.init_ok("Model Reference", status="OK")
+ except:
+ logger.init_err("Model Reference", status="Download Error")
+ self.models = json.load(open('./db.json'))
+ self.dependencies = json.load(open('./db_dep.json'))
+ logger.init_warn("Model Reference", status="Local")
+ self.available_models = []
+ self.tainted_models = []
+ self.available_dependencies = []
+ self.loaded_models = {}
+ self.hf_auth = None
+ self.set_authentication(hf_auth)
+
+ def init(self):
+ dependencies_available = []
+ for dependency in self.dependencies:
+ if self.check_available(self.get_dependency_files(dependency)):
+ dependencies_available.append(dependency)
+ self.available_dependencies = dependencies_available
+
+ models_available = []
+ for model in self.models:
+ if self.check_available(self.get_model_files(model)):
+ models_available.append(model)
+ self.available_models = models_available
+
+ if self.hf_auth is not None:
+ if 'username' not in self.hf_auth and 'password' not in self.hf_auth:
+ raise ValueError('hf_auth must contain username and password')
+ else:
+ if self.hf_auth['username'] == '' or self.hf_auth['password'] == '':
+ raise ValueError('hf_auth must contain username and password')
+ return True
+
+ def set_authentication(self, hf_auth=None):
+ # We do not let No authentication override previously set auth
+ if not hf_auth and self.hf_auth:
+ return
+ self.hf_auth = hf_auth
+
+ def get_model(self, model_name):
+ return self.models.get(model_name)
+
+ def get_filtered_models(self, **kwargs):
+ '''Get all model names.
+ Can filter based on metadata of the model reference db
+ '''
+ filtered_models = self.models
+ for keyword in kwargs:
+ iterating_models = filtered_models.copy()
+ filtered_models = {}
+ for model in iterating_models:
+ # logger.debug([keyword,iterating_models[model].get(keyword),kwargs[keyword]])
+ if iterating_models[model].get(keyword) == kwargs[keyword]:
+ filtered_models[model] = iterating_models[model]
+ return filtered_models
+
+ def get_filtered_model_names(self, **kwargs):
+ filtered_models = self.get_filtered_models(**kwargs)
+ return list(filtered_models.keys())
+
+ def get_dependency(self, dependency_name):
+ return self.dependencies[dependency_name]
+
+ def get_model_files(self, model_name):
+ return self.models[model_name]['config']['files']
+
+ def get_dependency_files(self, dependency_name):
+ return self.dependencies[dependency_name]['config']['files']
+
+ def get_model_download(self, model_name):
+ return self.models[model_name]['config']['download']
+
+ def get_dependency_download(self, dependency_name):
+ return self.dependencies[dependency_name]['config']['download']
+
+ def get_available_models(self):
+ return self.available_models
+
+ def get_available_dependencies(self):
+ return self.available_dependencies
+
+ def get_loaded_models(self):
+ return self.loaded_models
+
+ def get_loaded_models_names(self):
+ return list(self.loaded_models.keys())
+
+ def get_loaded_model(self, model_name):
+ return self.loaded_models[model_name]
+
+ def unload_model(self, model_name):
+ if model_name in self.loaded_models:
+ del self.loaded_models[model_name]
+ return True
+ return False
+
+ def unload_all_models(self):
+ for model in self.loaded_models:
+ del self.loaded_models[model]
+ return True
+
+ def taint_model(self,model_name):
+ '''Marks a model as not valid by remiving it from available_models'''
+ if model_name in self.available_models:
+ self.available_models.remove(model_name)
+ self.tainted_models.append(model_name)
+
+ def taint_models(self, models):
+ for model in models:
+ self.taint_model(model)
+
+ def load_model_from_config(self, model_path='', config_path='', map_location="cpu"):
+ config = OmegaConf.load(config_path)
+ pl_sd = torch.load(model_path, map_location=map_location)
+ if "global_step" in pl_sd:
+ logger.info(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ model = model.eval()
+ del pl_sd, sd, m, u
+ return model
+
+ def load_ckpt(self, model_name='', precision='half', gpu_id=0):
+ ckpt_path = self.get_model_files(model_name)[0]['path']
+ config_path = self.get_model_files(model_name)[1]['path']
+ model = self.load_model_from_config(model_path=ckpt_path, config_path=config_path)
+ device = torch.device(f"cuda:{gpu_id}")
+ model = (model if precision=='full' else model.half()).to(device)
+ torch_gc()
+ return {'model': model, 'device': device}
+
+ def load_realesrgan(self, model_name='', precision='half', gpu_id=0):
+
+ RealESRGAN_models = {
+ 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
+ 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ }
+
+ model_path = self.get_model_files(model_name)[0]['path']
+ device = torch.device(f"cuda:{gpu_id}")
+ model = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[models[model_name]['name']],
+ pre_pad=0, half=True if precision == 'half' else False, device=device)
+ return {'model': model, 'device': device}
+
+ def load_gfpgan(self, model_name='', gpu_id=0):
+
+ model_path = self.get_model_files(model_name)[0]['path']
+ device = torch.device(f"cuda:{gpu_id}")
+ model = GFPGANer(model_path=model_path, upscale=1, arch='clean',
+ channel_multiplier=2, bg_upsampler=None, device=device)
+ return {'model': model, 'device': device}
+
+ def load_blip(self, model_name='', precision='half', gpu_id=0, blip_image_eval_size=512, vit='base'):
+ # vit = 'base' or 'large'
+ model_path = self.get_model_files(model_name)[0]['path']
+ device = torch.device(f"cuda:{gpu_id}")
+ model = blip_decoder(pretrained=model_path,
+ med_config="configs/blip/med_config.json",
+ image_size=blip_image_eval_size, vit=vit)
+ model = model.eval()
+ model = (model if precision=='full' else model.half()).to(device)
+ return {'model': model, 'device': device}
+
+ def load_open_clip(self, model_name='', precision='half', gpu_id=0):
+ pretrained = self.get_model(model_name)['pretrained_name']
+ device = torch.device(f"cuda:{gpu_id}")
+ model, _, preprocesses = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, cache_dir='models/clip')
+ model = model.eval()
+ model = (model if precision=='full' else model.half()).to(device)
+ return {'model': model, 'device': device, 'preprocesses': preprocesses}
+
+ def load_clip(self, model_name='', precision='half', gpu_id=0):
+ device = torch.device(f"cuda:{gpu_id}")
+ model, preprocesses = clip.load(model_name, device=device, download_root='models/clip')
+ model = model.eval()
+ model = (model if precision=='full' else model.half()).to(device)
+ return {'model': model, 'device': device, 'preprocesses': preprocesses}
+
+ def load_model(self, model_name='', precision='half', gpu_id=0):
+ if model_name not in self.available_models:
+ return False
+ if self.models[model_name]['type'] == 'ckpt':
+ self.loaded_models[model_name] = self.load_ckpt(model_name, precision, gpu_id)
+ return True
+ elif self.models[model_name]['type'] == 'realesrgan':
+ self.loaded_models[model_name] = self.load_realesrgan(model_name, precision, gpu_id)
+ return True
+ elif self.models[model_name]['type'] == 'gfpgan':
+ self.loaded_models[model_name] = self.load_gfpgan(model_name, gpu_id)
+ return True
+ elif self.models[model_name]['type'] == 'blip':
+ self.loaded_models[model_name] = self.load_blip(model_name, precision, gpu_id, 512, 'base')
+ return True
+ elif self.models[model_name]['type'] == 'open_clip':
+ self.loaded_models[model_name] = self.load_open_clip(model_name, precision, gpu_id)
+ return True
+ elif self.models[model_name]['type'] == 'clip':
+ self.loaded_models[model_name] = self.load_clip(model_name, precision, gpu_id)
+ return True
+ else:
+ return False
+
+ def validate_model(self, model_name):
+ files = self.get_model_files(model_name)
+ all_ok = True
+ for file_details in files:
+ if not self.check_file_available(file_details['path']):
+ return False
+ if not self.validate_file(file_details):
+ return False
+ return True
+
+ def validate_file(self, file_details):
+ if 'md5sum' in file_details:
+ file_name = file_details['path']
+ logger.debug(f"Getting md5sum of {file_name}")
+ with open(file_name, 'rb') as file_to_check:
+ file_hash = hashlib.md5()
+ while chunk := file_to_check.read(8192):
+ file_hash.update(chunk)
+ if file_details['md5sum'] != file_hash.hexdigest():
+ return False
+ return True
+
+ def check_file_available(self, file_path):
+ return os.path.exists(file_path)
+
+ def check_available(self, files):
+ available = True
+ for file in files:
+ if not self.check_file_available(file['path']):
+ available = False
+ return available
+
+ def download_file(self, url, file_path):
+ # make directory
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ pbar_desc = file_path.split('/')[-1]
+ r = requests.get(url, stream=True)
+ with open(file_path, 'wb') as f:
+ with tqdm(
+ # all optional kwargs
+ unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
+ desc=pbar_desc, total=int(r.headers.get('content-length', 0))
+ ) as pbar:
+ for chunk in r.iter_content(chunk_size=16*1024):
+ if chunk:
+ f.write(chunk)
+ pbar.update(len(chunk))
+
+ def download_model(self, model_name):
+ if model_name in self.available_models:
+ logger.info(f"{model_name} is already available.")
+ return True
+ download = self.get_model_download(model_name)
+ files = self.get_model_files(model_name)
+ for i in range(len(download)):
+ file_path = f"{download[i]['file_path']}/{download[i]['file_name']}" if 'file_path' in download[i] else files[i]['path']
+
+ if 'file_url' in download[i]:
+ download_url = download[i]['file_url']
+ if 'hf_auth' in download[i]:
+ username = self.hf_auth['username']
+ password = self.hf_auth['password']
+ download_url = download_url.format(username=username, password=password)
+ if 'file_name' in download[i]:
+ download_name = download[i]['file_name']
+ if 'file_path' in download[i]:
+ download_path = download[i]['file_path']
+
+ if 'manual' in download[i]:
+ logger.warning(f"The model {model_name} requires manual download from {download_url}. Please place it in {download_path}/{download_name} then press ENTER to continue...")
+ input('')
+ continue
+ # TODO: simplify
+ if "file_content" in download[i]:
+ file_content = download[i]['file_content']
+ logger.info(f"writing {file_content} to {file_path}")
+ # make directory download_path
+ os.makedirs(download_path, exist_ok=True)
+ # write file_content to download_path/download_name
+ with open(os.path.join(download_path, download_name), 'w') as f:
+ f.write(file_content)
+ elif 'symlink' in download[i]:
+ logger.info(f"symlink {file_path} to {download[i]['symlink']}")
+ symlink = download[i]['symlink']
+ # make directory symlink
+ os.makedirs(download_path, exist_ok=True)
+ # make symlink from download_path/download_name to symlink
+ os.symlink(symlink, os.path.join(download_path, download_name))
+ elif 'git' in download[i]:
+ logger.info(f"git clone {download_url} to {file_path}")
+ # make directory download_path
+ os.makedirs(file_path, exist_ok=True)
+ git.Git(file_path).clone(download_url)
+ if 'post_process' in download[i]:
+ for post_process in download[i]['post_process']:
+ if 'delete' in post_process:
+ # delete folder post_process['delete']
+ logger.info(f"delete {post_process['delete']}")
+ try:
+ shutil.rmtree(post_process['delete'])
+ except PermissionError as e:
+ logger.error(f"[!] Something went wrong while deleting the `{post_process['delete']}`. Please delete it manually.")
+ logger.error("PermissionError: ", e)
+ else:
+ if not self.check_file_available(file_path) or model_name in self.tainted_models:
+ logger.debug(f'Downloading {download_url} to {file_path}')
+ self.download_file(download_url, file_path)
+ if not self.validate_model(model_name):
+ return False
+ if model_name in self.tainted_models:
+ self.tainted_models.remove(model_name)
+ self.init()
+ return True
+
+ def download_dependency(self, dependency_name):
+ if dependency_name in self.available_dependencies:
+ logger.info(f"{dependency_name} is already installed.")
+ return True
+ download = self.get_dependency_download(dependency_name)
+ files = self.get_dependency_files(dependency_name)
+ for i in range(len(download)):
+ if "git" in download[i]:
+ logger.warning("git download not implemented yet")
+ break
+
+ file_path = files[i]['path']
+ if 'file_url' in download[i]:
+ download_url = download[i]['file_url']
+ if 'file_name' in download[i]:
+ download_name = download[i]['file_name']
+ if 'file_path' in download[i]:
+ download_path = download[i]['file_path']
+ logger.debug(download_name)
+ if "unzip" in download[i]:
+ zip_path = f'temp/{download_name}.zip'
+ # os dirname zip_path
+ # mkdir temp
+ os.makedirs("temp", exist_ok=True)
+
+ self.download_file(download_url, zip_path)
+ logger.info(f"unzip {zip_path}")
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ zip_ref.extractall('temp/')
+ # move temp/sd-concepts-library-main/sd-concepts-library to download_path
+ logger.info(f"move temp/{download_name}-main/{download_name} to {download_path}")
+ shutil.move(f"temp/{download_name}-main/{download_name}", download_path)
+ logger.info(f"delete {zip_path}")
+ os.remove(zip_path)
+ logger.info(f"delete temp/{download_name}-main/")
+ shutil.rmtree(f"temp/{download_name}-main")
+ else:
+ if not self.check_file_available(file_path):
+ logger.init(f'{file_path}', status="Downloading")
+ self.download_file(download_url, file_path)
+ self.init()
+ return True
+
+ def download_all_models(self):
+ for model in self.get_filtered_model_names(download_all = True):
+ if not self.check_model_available(model):
+ logger.init(f"{model}", status="Downloading")
+ self.download_model(model)
+ else:
+ logger.info(f"{model} is already downloaded.")
+ return True
+
+ def download_all_dependencies(self):
+ for dependency in self.dependencies:
+ if not self.check_dependency_available(dependency):
+ logger.init(f"{dependency}",status="Downloading")
+ self.download_dependency(dependency)
+ else:
+ logger.info(f"{dependency} is already installed.")
+ return True
+
+ def download_all(self):
+ self.download_all_dependencies()
+ self.download_all_models()
+ return True
+
+ def check_all_available(self):
+ for model in self.models:
+ if not self.check_available(self.get_model_files(model)):
+ return False
+ for dependency in self.dependencies:
+ if not self.check_available(self.get_dependency_files(dependency)):
+ return False
+ return True
+
+ def check_model_available(self, model_name):
+ if model_name not in self.models:
+ return False
+ return self.check_available(self.get_model_files(model_name))
+
+ def check_dependency_available(self, dependency_name):
+ if dependency_name not in self.dependencies:
+ return False
+ return self.check_available(self.get_dependency_files(dependency_name))
+
+ def check_all_available(self):
+ for model in self.models:
+ if not self.check_model_available(model):
+ return False
+ for dependency in self.dependencies:
+ if not self.check_dependency_available(dependency):
+ return False
+ return True
+
+
+
+
+
+
+
+
diff --git a/scripts/nataili/postprocess/__init__.py b/scripts/nataili/postprocess/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/nataili/postprocess/upscaler.py b/scripts/nataili/postprocess/upscaler.py
new file mode 100644
index 0000000..529b438
--- /dev/null
+++ b/scripts/nataili/postprocess/upscaler.py
@@ -0,0 +1,48 @@
+# Class realesrgan
+# Inputs:
+# - model
+# - device
+# - output_dir
+# - output_ext
+# outupts:
+# - output_images
+import PIL
+from torchvision import transforms
+import numpy as np
+import os
+import cv2
+
+from nataili.util.save_sample import save_sample
+
+class realesrgan:
+ def __init__(self, model, device, output_dir, output_ext='jpg'):
+ self.model = model
+ self.device = device
+ self.output_dir = output_dir
+ self.output_ext = output_ext
+ self.output_images = []
+
+ def generate(self, input_image):
+ # load image
+ img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
+ if len(img.shape) == 3 and img.shape[2] == 4:
+ img_mode = 'RGBA'
+ else:
+ img_mode = None
+ # upscale
+ output, _ = self.model.enhance(img)
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
+ self.output_ext = 'png'
+
+ esrgan_sample = output[:,:,::-1]
+ esrgan_image = PIL.Image.fromarray(esrgan_sample)
+ # append model name to output image name
+ filename = os.path.basename(input_image)
+ filename = os.path.splitext(filename)[0]
+ filename = f'{filename}_esrgan'
+ filename_with_ext = f'{filename}.{self.output_ext}'
+ output_image = os.path.join(self.output_dir, filename_with_ext)
+ save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
+ self.output_images.append(output_image)
+ return
+
diff --git a/scripts/nataili/upscalers/__init__.py b/scripts/nataili/upscalers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/nataili/upscalers/realesrgan.py b/scripts/nataili/upscalers/realesrgan.py
new file mode 100644
index 0000000..529b438
--- /dev/null
+++ b/scripts/nataili/upscalers/realesrgan.py
@@ -0,0 +1,48 @@
+# Class realesrgan
+# Inputs:
+# - model
+# - device
+# - output_dir
+# - output_ext
+# outupts:
+# - output_images
+import PIL
+from torchvision import transforms
+import numpy as np
+import os
+import cv2
+
+from nataili.util.save_sample import save_sample
+
+class realesrgan:
+ def __init__(self, model, device, output_dir, output_ext='jpg'):
+ self.model = model
+ self.device = device
+ self.output_dir = output_dir
+ self.output_ext = output_ext
+ self.output_images = []
+
+ def generate(self, input_image):
+ # load image
+ img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
+ if len(img.shape) == 3 and img.shape[2] == 4:
+ img_mode = 'RGBA'
+ else:
+ img_mode = None
+ # upscale
+ output, _ = self.model.enhance(img)
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
+ self.output_ext = 'png'
+
+ esrgan_sample = output[:,:,::-1]
+ esrgan_image = PIL.Image.fromarray(esrgan_sample)
+ # append model name to output image name
+ filename = os.path.basename(input_image)
+ filename = os.path.splitext(filename)[0]
+ filename = f'{filename}_esrgan'
+ filename_with_ext = f'{filename}.{self.output_ext}'
+ output_image = os.path.join(self.output_dir, filename_with_ext)
+ save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
+ self.output_images.append(output_image)
+ return
+
diff --git a/scripts/nataili/util/__init__.py b/scripts/nataili/util/__init__.py
new file mode 100644
index 0000000..766cb27
--- /dev/null
+++ b/scripts/nataili/util/__init__.py
@@ -0,0 +1 @@
+from nataili.util.logger import logger,set_logger_verbosity, quiesce_logger, test_logger
diff --git a/scripts/nataili/util/cache.py b/scripts/nataili/util/cache.py
new file mode 100644
index 0000000..ef7711c
--- /dev/null
+++ b/scripts/nataili/util/cache.py
@@ -0,0 +1,16 @@
+import gc
+
+import torch
+import threading
+import pynvml
+import time
+
+with torch.no_grad():
+ def torch_gc():
+ for _ in range(2):
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ torch.cuda.synchronize()
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.reset_accumulated_memory_stats()
diff --git a/scripts/nataili/util/check_prompt_length.py b/scripts/nataili/util/check_prompt_length.py
new file mode 100644
index 0000000..4953eac
--- /dev/null
+++ b/scripts/nataili/util/check_prompt_length.py
@@ -0,0 +1,18 @@
+def check_prompt_length(model, prompt, comments):
+ """this function tests if prompt is too long, and if so, adds a message to comments"""
+
+ tokenizer = model.cond_stage_model.tokenizer
+ max_length = model.cond_stage_model.max_length
+
+ info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
+ return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
+ ovf = info['overflowing_tokens'][0]
+ overflowing_count = ovf.shape[0]
+ if overflowing_count == 0:
+ return
+
+ vocab = {v: k for k, v in tokenizer.get_vocab().items()}
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+ del tokenizer
\ No newline at end of file
diff --git a/scripts/nataili/util/get_next_sequence_number.py b/scripts/nataili/util/get_next_sequence_number.py
new file mode 100644
index 0000000..ede028f
--- /dev/null
+++ b/scripts/nataili/util/get_next_sequence_number.py
@@ -0,0 +1,22 @@
+from pathlib import Path
+
+def get_next_sequence_number(path, prefix=''):
+ """
+ Determines and returns the next sequence number to use when saving an
+ image in the specified directory.
+
+ If a prefix is given, only consider files whose names start with that
+ prefix, and strip the prefix from filenames before extracting their
+ sequence number.
+
+ The sequence starts at 0.
+ """
+ result = -1
+ for p in Path(path).iterdir():
+ if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
+ tmp = p.name[len(prefix):]
+ try:
+ result = max(int(tmp.split('-')[0]), result)
+ except ValueError:
+ pass
+ return result + 1
\ No newline at end of file
diff --git a/scripts/nataili/util/image_grid.py b/scripts/nataili/util/image_grid.py
new file mode 100644
index 0000000..7ea85eb
--- /dev/null
+++ b/scripts/nataili/util/image_grid.py
@@ -0,0 +1,21 @@
+import math
+
+import PIL
+
+
+def image_grid(imgs, n_rows=None):
+ if n_rows is not None:
+ rows = n_rows
+ else:
+ rows = math.sqrt(len(imgs))
+ rows = round(rows)
+
+ cols = math.ceil(len(imgs) / rows)
+
+ w, h = imgs[0].size
+ grid = PIL.Image.new('RGB', size=(cols * w, rows * h), color='black')
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+
+ return grid
diff --git a/scripts/nataili/util/load_learned_embed_in_clip.py b/scripts/nataili/util/load_learned_embed_in_clip.py
new file mode 100644
index 0000000..9507e58
--- /dev/null
+++ b/scripts/nataili/util/load_learned_embed_in_clip.py
@@ -0,0 +1,40 @@
+import os
+
+import torch
+
+
+def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
+ # separate token and the embeds
+ if learned_embeds_path.endswith('.pt'):
+ # old format
+ # token = * so replace with file directory name when converting
+ trained_token = os.path.basename(learned_embeds_path)
+ params_dict = {
+ trained_token: torch.tensor(list(loaded_learned_embeds['string_to_param'].items())[0][1])
+ }
+ learned_embeds_path = os.path.splitext(learned_embeds_path)[0] + '.bin'
+ torch.save(params_dict, learned_embeds_path)
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
+ trained_token = list(loaded_learned_embeds.keys())[0]
+ embeds = loaded_learned_embeds[trained_token]
+ elif learned_embeds_path.endswith('.bin'):
+ trained_token = list(loaded_learned_embeds.keys())[0]
+ embeds = loaded_learned_embeds[trained_token]
+
+ embeds = loaded_learned_embeds[trained_token]
+ # cast to dtype of text_encoder
+ dtype = text_encoder.get_input_embeddings().weight.dtype
+ embeds.to(dtype)
+
+ # add the token in tokenizer
+ token = token if token is not None else trained_token
+ num_added_tokens = tokenizer.add_tokens(token)
+
+ # resize the token embeddings
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # get the id for the token and assign the embeds
+ token_id = tokenizer.convert_tokens_to_ids(token)
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
+ return token
diff --git a/scripts/nataili/util/logger.py b/scripts/nataili/util/logger.py
new file mode 100644
index 0000000..d012c83
--- /dev/null
+++ b/scripts/nataili/util/logger.py
@@ -0,0 +1,102 @@
+import sys
+from functools import partialmethod
+from loguru import logger
+
+STDOUT_LEVELS = ["GENERATION", "PROMPT"]
+INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
+MESSAGE_LEVELS = ["MESSAGE"]
+# By default we're at error level or higher
+verbosity = 20
+quiet = 0
+
+def set_logger_verbosity(count):
+ global verbosity
+ # The count comes reversed. So count = 0 means minimum verbosity
+ # While count 5 means maximum verbosity
+ # So the more count we have, the lowe we drop the versbosity maximum
+ verbosity = 20 - (count * 10)
+
+def quiesce_logger(count):
+ global quiet
+ # The bigger the count, the more silent we want our logger
+ quiet = count * 10
+
+def is_stdout_log(record):
+ if record["level"].name not in STDOUT_LEVELS:
+ return(False)
+ if record["level"].no < verbosity + quiet:
+ return(False)
+ return(True)
+
+def is_init_log(record):
+ if record["level"].name not in INIT_LEVELS:
+ return(False)
+ if record["level"].no < verbosity + quiet:
+ return(False)
+ return(True)
+
+def is_msg_log(record):
+ if record["level"].name not in MESSAGE_LEVELS:
+ return(False)
+ if record["level"].no < verbosity + quiet:
+ return(False)
+ return(True)
+
+def is_stderr_log(record):
+ if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
+ return(False)
+ if record["level"].no < verbosity + quiet:
+ return(False)
+ return(True)
+
+def test_logger():
+ logger.generation("This is a generation message\nIt is typically multiline\nThee Lines".encode("unicode_escape").decode("utf-8"))
+ logger.prompt("This is a prompt message")
+ logger.debug("Debug Message")
+ logger.info("Info Message")
+ logger.warning("Info Warning")
+ logger.error("Error Message")
+ logger.critical("Critical Message")
+ logger.init("This is an init message", status="Starting")
+ logger.init_ok("This is an init message", status="OK")
+ logger.init_warn("This is an init message", status="Warning")
+ logger.init_err("This is an init message", status="Error")
+ logger.message("This is user message")
+ sys.exit()
+
+
+logfmt = "{level: <10} | {time:YYYY-MM-DD HH:mm:ss} | {name}:{function}:{line} - {message}"
+genfmt = "{level: <10} @ {time:YYYY-MM-DD HH:mm:ss} | {message}"
+initfmt = "INIT | {extra[status]: <11} | {message}"
+msgfmt = "{level: <10} | {message}"
+
+try:
+ logger.level("GENERATION", no=24, color="")
+ logger.level("PROMPT", no=23, color="")
+ logger.level("INIT", no=31, color="")
+ logger.level("INIT_OK", no=31, color="")
+ logger.level("INIT_WARN", no=31, color="")
+ logger.level("INIT_ERR", no=31, color="")
+ # Messages contain important information without which this application might not be able to be used
+ # As such, they have the highest priority
+ logger.level("MESSAGE", no=61, color="")
+except TypeError:
+ pass
+
+logger.__class__.generation = partialmethod(logger.__class__.log, "GENERATION")
+logger.__class__.prompt = partialmethod(logger.__class__.log, "PROMPT")
+logger.__class__.init = partialmethod(logger.__class__.log, "INIT")
+logger.__class__.init_ok = partialmethod(logger.__class__.log, "INIT_OK")
+logger.__class__.init_warn = partialmethod(logger.__class__.log, "INIT_WARN")
+logger.__class__.init_err = partialmethod(logger.__class__.log, "INIT_ERR")
+logger.__class__.message = partialmethod(logger.__class__.log, "MESSAGE")
+
+config = {
+ "handlers": [
+ {"sink": sys.stderr, "format": logfmt, "colorize":True, "filter": is_stderr_log},
+ {"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize":True, "filter": is_stdout_log},
+ {"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize":True, "filter": is_init_log},
+ {"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize":True, "filter": is_msg_log}
+ ],
+}
+logger.configure(**config)
diff --git a/scripts/nataili/util/save_sample.py b/scripts/nataili/util/save_sample.py
new file mode 100644
index 0000000..5c791d3
--- /dev/null
+++ b/scripts/nataili/util/save_sample.py
@@ -0,0 +1,20 @@
+import os
+
+def save_sample(image, filename, sample_path, extension='png', jpg_quality=95, webp_quality=95, webp_lossless=True, png_compression=9):
+ path = os.path.join(sample_path, filename + '.' + extension)
+ if os.path.exists(path):
+ return False
+ if not os.path.exists(sample_path):
+ os.makedirs(sample_path)
+ if extension == 'png':
+ image.save(path, format='PNG', compress_level=png_compression)
+ elif extension == 'jpg':
+ image.save(path, quality=jpg_quality, optimize=True)
+ elif extension == 'webp':
+ image.save(path, quality=webp_quality, lossless=webp_lossless)
+ else:
+ return False
+ if os.path.exists(path):
+ return True
+ else:
+ return False
diff --git a/scripts/nataili/util/seed_to_int.py b/scripts/nataili/util/seed_to_int.py
new file mode 100644
index 0000000..61cc9fd
--- /dev/null
+++ b/scripts/nataili/util/seed_to_int.py
@@ -0,0 +1,22 @@
+import random
+
+def seed_to_int(s):
+ if type(s) is int:
+ return s
+ if s is None or s == '':
+ return random.randint(0, 2**32 - 1)
+
+ if type(s) is list:
+ seed_list = []
+ for seed in s:
+ if seed is None or seed == '':
+ seed_list.append(random.randint(0, 2**32 - 1))
+ else:
+ seed_list = s
+
+ return seed_list
+
+ n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
+ while n >= 2**32:
+ n = n >> 32
+ return n
\ No newline at end of file
diff --git a/streamlit_webview.py b/streamlit_webview.py
new file mode 100644
index 0000000..9853c4e
--- /dev/null
+++ b/streamlit_webview.py
@@ -0,0 +1,15 @@
+import os, webview
+from streamlit.web import bootstrap
+from streamlit import config as _config
+
+webview.create_window('Sygil', 'http://localhost:8501', width=1000, height=800, min_size=(500, 500))
+webview.start()
+
+dirname = os.path.dirname(__file__)
+filename = os.path.join(dirname, 'scripts/webui_streamlit.py')
+
+_config.set_option("server.headless", True)
+args = []
+
+#streamlit.cli.main_run(filename, args)
+bootstrap.run(filename,'',args, flag_options={})
\ No newline at end of file
diff --git a/webui.cmd b/webui.cmd
index e1248a9..f54fe8b 100644
--- a/webui.cmd
+++ b/webui.cmd
@@ -1,17 +1,17 @@
@echo off
:: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
-
+::
:: Copyright 2022 Sygil-Dev team.
:: This program is free software: you can redistribute it and/or modify
:: it under the terms of the GNU Affero General Public License as published by
:: the Free Software Foundation, either version 3 of the License, or
:: (at your option) any later version.
-
+::
:: This program is distributed in the hope that it will be useful,
:: but WITHOUT ANY WARRANTY; without even the implied warranty of
:: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
:: GNU Affero General Public License for more details.
-
+::
:: You should have received a copy of the GNU Affero General Public License
:: along with this program. If not, see .
:: Run all commands using this script's directory as the working directory
@@ -102,12 +102,11 @@ call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
:PROMPT
set SETUPTOOLS_USE_DISTUTILS=stdlib
-IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" (
- set "PYTHONPATH=%~dp0"
- python scripts\relauncher.py %*
+IF EXIST "models\ldm\stable-diffusion-v1\Stable Diffusion v1.5.ckpt" (
+ python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
) ELSE (
- echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'.
- pause
+ echo Your model file does not exist! Once the WebUI launches please visit the Model Manager page and download the models by using the Download button for each model.
+ python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
)
::cmd /k
diff --git a/webui-streamlit.cmd b/webui_legacy.cmd
similarity index 97%
rename from webui-streamlit.cmd
rename to webui_legacy.cmd
index 45414be..654e21c 100644
--- a/webui-streamlit.cmd
+++ b/webui_legacy.cmd
@@ -1,17 +1,17 @@
@echo off
:: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
-::
+
:: Copyright 2022 Sygil-Dev team.
:: This program is free software: you can redistribute it and/or modify
:: it under the terms of the GNU Affero General Public License as published by
:: the Free Software Foundation, either version 3 of the License, or
:: (at your option) any later version.
-::
+
:: This program is distributed in the hope that it will be useful,
:: but WITHOUT ANY WARRANTY; without even the implied warranty of
:: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
:: GNU Affero General Public License for more details.
-::
+
:: You should have received a copy of the GNU Affero General Public License
:: along with this program. If not, see .
:: Run all commands using this script's directory as the working directory
@@ -99,7 +99,8 @@ call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
:PROMPT
set SETUPTOOLS_USE_DISTUTILS=stdlib
IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" (
- python -m streamlit run scripts\webui_streamlit.py --theme.base dark --server.address localhost
+ set "PYTHONPATH=%~dp0"
+ python scripts\relauncher.py %*
) ELSE (
echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'.
pause