Fixed the serverAddress default value stopping people from using the UI in services like runpod. (#1621)

This commit is contained in:
Alejandro Gil 2022-10-29 09:56:50 -07:00 committed by GitHub
parent 1edaccfeb5
commit c5981694c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 2 additions and 1577 deletions

View File

@ -36,7 +36,6 @@ enableWebsocketCompression = false
[browser]
gatherUsageStats = false
serverPort = 8501
serverAddress = "localhost"
[mapbox]
token = ""

View File

@ -251,7 +251,7 @@ def layout():
if not st.session_state['defaults'].admin.hide_browser_setting:
with st.expander("Browser", expanded=True):
st.session_state["streamlit_config"]['browser']['serverAddress'] = st.text_input("Server Address",
value=st.session_state["streamlit_config"]['browser']['serverAddress'],
value=st.session_state["streamlit_config"]['browser']['serverAddress'] if "serverAddress" in st.session_state["streamlit_config"] else "localhost",
help="Internet address where users should point their browsers in order \
to connect to the app. Can be IP address or DNS name and path.\
This is used to: - Set the correct URL for CORS and XSRF protection purposes. \

View File

@ -1,551 +0,0 @@
import os
import re
import sys
import k_diffusion as K
import tqdm
from contextlib import contextmanager, nullcontext
import skimage
import numpy as np
import PIL
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.kdiffusion import CFGMaskedDenoiser, KDiffusionSampler
from ldm.models.diffusion.plms import PLMSSampler
from nataili.util.cache import torch_gc
from nataili.util.check_prompt_length import check_prompt_length
from nataili.util.get_next_sequence_number import get_next_sequence_number
from nataili.util.image_grid import image_grid
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
from nataili.util.save_sample import save_sample
from nataili.util.seed_to_int import seed_to_int
from slugify import slugify
import PIL
class img2img:
def __init__(self, model, device, output_dir, save_extension='jpg',
output_file_path=False, load_concepts=False, concepts_dir=None,
verify_input=True, auto_cast=True):
self.model = model
self.output_dir = output_dir
self.output_file_path = output_file_path
self.save_extension = save_extension
self.load_concepts = load_concepts
self.concepts_dir = concepts_dir
self.verify_input = verify_input
self.auto_cast = auto_cast
self.device = device
self.comments = []
self.output_images = []
self.info = ''
self.stats = ''
self.images = []
def create_random_tensors(self, shape, seeds):
xs = []
for seed in seeds:
torch.manual_seed(seed)
# randn results depend on device; gpu and cpu get different results for same seed;
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
# but the original script had it like this so i do not dare change it for now because
# it will break everyone's seeds.
xs.append(torch.randn(shape, device=self.device))
x = torch.stack(xs)
return x
def process_prompt_tokens(self, prompt_tokens):
# compviz codebase
tokenizer = self.model.cond_stage_model.tokenizer
text_encoder = self.model.cond_stage_model.transformer
# diffusers codebase
#tokenizer = pipe.tokenizer
#text_encoder = pipe.text_encoder
ext = ('.pt', '.bin')
for token_name in prompt_tokens:
embedding_path = os.path.join(self.concepts_dir, token_name)
if os.path.exists(embedding_path):
for files in os.listdir(embedding_path):
if files.endswith(ext):
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
else:
print(f"Concept {token_name} not found in {self.concepts_dir}")
del tokenizer, text_encoder
return
del tokenizer, text_encoder
def resize_image(self, resize_mode, im, width, height):
LANCZOS = (PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, 'Resampling') else PIL.Image.LANCZOS)
if resize_mode == "resize":
res = im.resize((width, height), resample=LANCZOS)
elif resize_mode == "crop":
ratio = width / height
src_ratio = im.width / im.height
src_w = width if ratio > src_ratio else im.width * height // im.height
src_h = height if ratio <= src_ratio else im.height * width // im.width
resized = im.resize((src_w, src_h), resample=LANCZOS)
res = PIL.Image.new("RGBA", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
else:
ratio = width / height
src_ratio = im.width / im.height
src_w = width if ratio < src_ratio else im.width * height // im.height
src_h = height if ratio >= src_ratio else im.height * width // im.width
resized = im.resize((src_w, src_h), resample=LANCZOS)
res = PIL.Image.new("RGBA", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
return res
#
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(self, data):
if data.ndim > 2: # has channels
out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
for c in range(data.shape[2]):
c_data = data[:,:,c]
out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho")
out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c])
else: # one channel
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho")
out_fft[:,:] = np.fft.ifftshift(out_fft[:,:])
return out_fft
def _ifft2(self, data):
if data.ndim > 2: # has channels
out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
for c in range(data.shape[2]):
c_data = data[:,:,c]
out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho")
out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c])
else: # one channel
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho")
out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:])
return out_ifft
def _get_gaussian_window(self, width, height, std=3.14, mode=0):
window_scale_x = float(width / min(width, height))
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
x = (np.arange(width) / width * 2. - 1.) * window_scale_x
for y in range(height):
fy = (y / height * 2. - 1.) * window_scale_y
if mode == 0:
window[:, y] = np.exp(-(x**2+fy**2) * std)
else:
window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian
return window
def _get_masked_window_rgb(self, np_mask_grey, hardness=1.):
np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
if hardness != 1.:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
np_mask_rgb[:,:,c] = hardened[:]
return np_mask_rgb
def get_matched_noise(self, _np_src_image, np_mask_rgb, noise_q, color_variation):
"""
Explanation:
Getting good results in/out-painting with stable diffusion can be challenging.
Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data
in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well.
Provided here is my take on a potential solution to this problem.
By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src.
Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence
by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can
be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color
or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently)
This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area.
The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle.
Carefully managing color and brightness with histogram matching is also essential to achieving good coherence.
noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.)
color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01)
This code is provided as is under the Unlicense (https://unlicense.org/)
Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz].
Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/)
This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot)
"""
global DEBUG_MODE
global TMP_ROOT_PATH
width = _np_src_image.shape[0]
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.)
np_src_grey = (np.sum(np_src_image, axis=2)/3.)
all_mask = np.ones((width, height), dtype=bool)
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
windowed_image = _np_src_image * (1.-self._get_masked_window_rgb(np_mask_grey))
windowed_image /= np.max(windowed_image)
windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
#windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area
#_save_debug_img(windowed_image, "windowed_src_img")
src_fft = self._fft2(windowed_image) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
#_save_debug_img(src_dist, "windowed_src_dist")
noise_window = self._get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
noise_rgb = np.random.random_sample((width, height, num_channels))
noise_grey = (np.sum(noise_rgb, axis=2)/3.)
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
for c in range(num_channels):
noise_rgb[:,:,c] += (1. - color_variation) * noise_grey
noise_fft = self._fft2(noise_rgb)
for c in range(num_channels):
noise_fft[:,:,c] *= noise_window
noise_rgb = np.real(self._ifft2(noise_fft))
shaped_noise_fft = self._fft2(noise_rgb)
shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now
contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
# scikit-image is used for histogram matching, very convenient!
shaped_noise = np.real(self._ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1)
shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
#_save_debug_img(shaped_noise, "shaped_noise")
matched_noise = np.zeros((width, height, num_channels))
matched_noise = shaped_noise[:]
#matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1)
#matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb
#_save_debug_img(matched_noise, "matched_noise")
"""
todo:
color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled
"""
return np.clip(matched_noise, 0., 1.)
def find_noise_for_image(self, model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None):
image = np.array(init_image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
image = 2. * image - 1.
image = image.to(device)
x = model.get_first_stage_encoding(model.encode_first_stage(image))
uncond = model.get_learned_conditioning([''])
cond = model.get_learned_conditioning([prompt])
s_in = x.new_ones([x.shape[0]])
dnw = K.external.CompVisDenoiser(model)
sigmas = dnw.get_sigmas(steps).flip(0)
if verbose:
print(sigmas)
for i in tqdm.trange(1, len(sigmas)):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
cond_in = torch.cat([uncond, cond])
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
if i == 1:
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
else:
t = dnw.sigma_to_t(sigma_in)
eps = model.apply_model(x_in * c_in, t, cond=cond_in)
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale
if i == 1:
d = (x - denoised) / (2 * sigmas[i])
else:
d = (x - denoised) / sigmas[i - 1]
dt = sigmas[i] - sigmas[i - 1]
x = x + d * dt
return x / sigmas[-1]
def generate(self, prompt: str, init_img=None, init_mask=None, mask_mode='mask', resize_mode='resize', noise_mode='seed',
denoising_strength:float=0.8, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
seed = seed_to_int(seed)
image_dict = {
"seed": seed
}
# Init image is assumed to be a PIL image
init_img = self.resize_image('resize', init_img, width, height)
if sampler_name == 'PLMS':
sampler = PLMSSampler(self.model)
elif sampler_name == 'DDIM':
sampler = DDIMSampler(self.model)
elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(self.model,'dpm_2')
elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(self.model,'euler_ancestral')
elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(self.model,'euler')
elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(self.model,'heun')
elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(self.model,'lms')
else:
raise Exception("Unknown sampler: " + sampler_name)
torch_gc()
def process_init_mask(init_mask: PIL.Image):
if init_mask.mode == "RGBA":
init_mask = init_mask.convert('RGBA')
background = PIL.Image.new('RGBA', init_mask.size, (0, 0, 0))
init_mask = PIL.Image.alpha_composite(background, init_mask)
init_mask = init_mask.convert('RGB')
return init_mask
if mask_mode == "mask":
if init_mask:
init_mask = process_init_mask(init_mask)
elif mask_mode == "invert":
if init_mask:
init_mask = process_init_mask(init_mask)
init_mask = PIL.ImageOps.invert(init_mask)
elif mask_mode == "alpha":
init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
init_mask = init_img_transparency
init_mask = init_mask.convert("RGB")
init_mask = self.resize_image(resize_mode, init_mask, width, height)
init_mask = init_mask.convert("RGB")
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(denoising_strength * ddim_steps)
if init_mask is not None and (noise_mode == "matched" or noise_mode == "find_and_matched") and init_img is not None:
noise_q = 0.99
color_variation = 0.0
mask_blend_factor = 1.0
np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
np_mask_rgb = 1. - (np.asarray(PIL.ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
np_mask_rgb -= np.min(np_mask_rgb)
np_mask_rgb /= np.max(np_mask_rgb)
np_mask_rgb = 1. - np_mask_rgb
np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
#np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
#np_mask_rgb = np_mask_rgb + blurred
np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
noise_rgb = self.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
noised = noise_rgb[:]
blend_mask_rgb **= (2.)
noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
ref_mask = np_mask_grey < 1e-3
all_mask = np.ones((height, width), dtype=bool)
noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
init_img = PIL.Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
def init():
image = init_img.convert('RGB')
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
mask_channel = None
if init_mask:
alpha = self.resize_image(resize_mode, init_mask, width // 8, height // 8)
mask_channel = alpha.split()[-1]
mask = None
if mask_channel is not None:
mask = np.array(mask_channel).astype(np.float32) / 255.0
mask = (1 - mask)
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3)
mask = torch.from_numpy(mask).to(self.model.device)
init_image = 2. * image - 1.
init_image = init_image.to(self.model.device)
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
return init_latent, mask,
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
t_enc_steps = t_enc
obliterate = False
if ddim_steps == t_enc_steps:
t_enc_steps = t_enc_steps - 1
obliterate = True
if sampler_name != 'DDIM':
x0, z_mask = init_data
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
xi = x0 + noise
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=xi.device)
xi = (z_mask * noise) + ((1-z_mask) * xi)
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
else:
x0, z_mask = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(self.model.device))
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=z_enc.device)
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
# decode it
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=unconditional_conditioning,
z_mask=z_mask, x0=x0)
return samples_ddim
torch_gc()
if self.load_concepts and self.concepts_dir is not None:
prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
if prompt_tokens:
self.process_prompt_tokens(prompt_tokens)
os.makedirs(self.output_dir, exist_ok=True)
sample_path = os.path.join(self.output_dir, "samples")
os.makedirs(sample_path, exist_ok=True)
if self.verify_input:
try:
check_prompt_length(self.model, prompt, self.comments)
except:
import traceback
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
all_prompts = batch_size * n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))]
precision_scope = torch.autocast if self.auto_cast else nullcontext
with torch.no_grad(), precision_scope("cuda"):
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
uc = self.model.get_learned_conditioning(len(prompts) * [''])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = self.model.get_learned_conditioning(prompts)
opt_C = 4
opt_f = 8
shape = [opt_C, height // opt_f, width // opt_f]
x = self.create_random_tensors(shape, seeds=seeds)
init_data = init()
samples_ddim = sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for i, x_sample in enumerate(x_samples_ddim):
sanitized_prompt = slugify(prompts[i])
full_path = os.path.join(os.getcwd(), sample_path)
sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i)
filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
image = PIL.Image.fromarray(x_sample)
image_dict['image'] = image
self.images.append(image_dict)
if save_individual_images:
path = os.path.join(sample_path, filename + '.' + self.save_extension)
success = save_sample(image, filename, sample_path_i, self.save_extension)
if success:
if self.output_file_path:
self.output_images.append(path)
else:
self.output_images.append(image)
else:
return
self.info = f"""
{prompt}
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
""".strip()
self.stats = f'''
'''
for comment in self.comments:
self.info += "\n\n" + comment
torch_gc()
del sampler
return

View File

@ -1,201 +0,0 @@
import os
import re
import sys
from contextlib import contextmanager, nullcontext
import numpy as np
import PIL
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.kdiffusion import KDiffusionSampler
from ldm.models.diffusion.plms import PLMSSampler
from nataili.util.cache import torch_gc
from nataili.util.check_prompt_length import check_prompt_length
from nataili.util.get_next_sequence_number import get_next_sequence_number
from nataili.util.image_grid import image_grid
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
from nataili.util.save_sample import save_sample
from nataili.util.seed_to_int import seed_to_int
from slugify import slugify
class txt2img:
def __init__(self, model, device, output_dir, save_extension='jpg',
output_file_path=False, load_concepts=False, concepts_dir=None,
verify_input=True, auto_cast=True):
self.model = model
self.output_dir = output_dir
self.output_file_path = output_file_path
self.save_extension = save_extension
self.load_concepts = load_concepts
self.concepts_dir = concepts_dir
self.verify_input = verify_input
self.auto_cast = auto_cast
self.device = device
self.comments = []
self.output_images = []
self.info = ''
self.stats = ''
self.images = []
def create_random_tensors(self, shape, seeds):
xs = []
for seed in seeds:
torch.manual_seed(seed)
# randn results depend on device; gpu and cpu get different results for same seed;
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
# but the original script had it like this so i do not dare change it for now because
# it will break everyone's seeds.
xs.append(torch.randn(shape, device=self.device))
x = torch.stack(xs)
return x
def process_prompt_tokens(self, prompt_tokens):
# compviz codebase
tokenizer = self.model.cond_stage_model.tokenizer
text_encoder = self.model.cond_stage_model.transformer
# diffusers codebase
#tokenizer = pipe.tokenizer
#text_encoder = pipe.text_encoder
ext = ('.pt', '.bin')
for token_name in prompt_tokens:
embedding_path = os.path.join(self.concepts_dir, token_name)
if os.path.exists(embedding_path):
for files in os.listdir(embedding_path):
if files.endswith(ext):
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
else:
print(f"Concept {token_name} not found in {self.concepts_dir}")
del tokenizer, text_encoder
return
del tokenizer, text_encoder
def generate(self, prompt: str, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
seed = seed_to_int(seed)
image_dict = {
"seed": seed
}
negprompt = ''
if '###' in prompt:
prompt, negprompt = prompt.split('###', 1)
prompt = prompt.strip()
negprompt = negprompt.strip()
if sampler_name == 'PLMS':
sampler = PLMSSampler(self.model)
elif sampler_name == 'DDIM':
sampler = DDIMSampler(self.model)
elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(self.model,'dpm_2')
elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(self.model,'euler_ancestral')
elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(self.model,'euler')
elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(self.model,'heun')
elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(self.model,'lms')
else:
raise Exception("Unknown sampler: " + sampler_name)
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=unconditional_conditioning, x_T=x)
return samples_ddim
torch_gc()
if self.load_concepts and self.concepts_dir is not None:
prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
if prompt_tokens:
self.process_prompt_tokens(prompt_tokens)
os.makedirs(self.output_dir, exist_ok=True)
sample_path = os.path.join(self.output_dir, "samples")
os.makedirs(sample_path, exist_ok=True)
if self.verify_input:
try:
check_prompt_length(self.model, prompt, self.comments)
except:
import traceback
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
all_prompts = batch_size * n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))]
precision_scope = torch.autocast if self.auto_cast else nullcontext
with torch.no_grad(), precision_scope("cuda"):
for n in range(n_iter):
print(f"Iteration: {n+1}/{n_iter}")
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
uc = self.model.get_learned_conditioning(len(prompts) * [negprompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = self.model.get_learned_conditioning(prompts)
opt_C = 4
opt_f = 8
shape = [opt_C, height // opt_f, width // opt_f]
x = self.create_random_tensors(shape, seeds=seeds)
samples_ddim = sample(init_data=None, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for i, x_sample in enumerate(x_samples_ddim):
sanitized_prompt = slugify(prompts[i])
full_path = os.path.join(os.getcwd(), sample_path)
sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i)
filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
image = PIL.Image.fromarray(x_sample)
image_dict['image'] = image
self.images.append(image_dict)
if save_individual_images:
path = os.path.join(sample_path, filename + '.' + self.save_extension)
success = save_sample(image, filename, sample_path_i, self.save_extension)
if success:
if self.output_file_path:
self.output_images.append(path)
else:
self.output_images.append(image)
else:
return
self.info = f"""
{prompt}
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
""".strip()
self.stats = f'''
'''
for comment in self.comments:
self.info += "\n\n" + comment
torch_gc()
del sampler
return

View File

@ -1,458 +0,0 @@
import os
import json
import shutil
import zipfile
import requests
import git
import torch
import hashlib
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from transformers import logging
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from ldm.models.blip import blip_decoder
from tqdm import tqdm
import open_clip
import clip
from nataili.util.cache import torch_gc
from nataili.util import logger
logging.set_verbosity_error()
models = json.load(open('./db.json'))
dependencies = json.load(open('./db_dep.json'))
remote_models = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
remote_dependencies = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db_dep.json"
class ModelManager():
def __init__(self, hf_auth=None, download=True):
if download:
try:
logger.init("Model Reference", status="Downloading")
r = requests.get(remote_models)
self.models = r.json()
r = requests.get(remote_dependencies)
self.dependencies = json.load(open('./db_dep.json'))
logger.init_ok("Model Reference", status="OK")
except:
logger.init_err("Model Reference", status="Download Error")
self.models = json.load(open('./db.json'))
self.dependencies = json.load(open('./db_dep.json'))
logger.init_warn("Model Reference", status="Local")
self.available_models = []
self.tainted_models = []
self.available_dependencies = []
self.loaded_models = {}
self.hf_auth = None
self.set_authentication(hf_auth)
def init(self):
dependencies_available = []
for dependency in self.dependencies:
if self.check_available(self.get_dependency_files(dependency)):
dependencies_available.append(dependency)
self.available_dependencies = dependencies_available
models_available = []
for model in self.models:
if self.check_available(self.get_model_files(model)):
models_available.append(model)
self.available_models = models_available
if self.hf_auth is not None:
if 'username' not in self.hf_auth and 'password' not in self.hf_auth:
raise ValueError('hf_auth must contain username and password')
else:
if self.hf_auth['username'] == '' or self.hf_auth['password'] == '':
raise ValueError('hf_auth must contain username and password')
return True
def set_authentication(self, hf_auth=None):
# We do not let No authentication override previously set auth
if not hf_auth and self.hf_auth:
return
self.hf_auth = hf_auth
def get_model(self, model_name):
return self.models.get(model_name)
def get_filtered_models(self, **kwargs):
'''Get all model names.
Can filter based on metadata of the model reference db
'''
filtered_models = self.models
for keyword in kwargs:
iterating_models = filtered_models.copy()
filtered_models = {}
for model in iterating_models:
# logger.debug([keyword,iterating_models[model].get(keyword),kwargs[keyword]])
if iterating_models[model].get(keyword) == kwargs[keyword]:
filtered_models[model] = iterating_models[model]
return filtered_models
def get_filtered_model_names(self, **kwargs):
filtered_models = self.get_filtered_models(**kwargs)
return list(filtered_models.keys())
def get_dependency(self, dependency_name):
return self.dependencies[dependency_name]
def get_model_files(self, model_name):
return self.models[model_name]['config']['files']
def get_dependency_files(self, dependency_name):
return self.dependencies[dependency_name]['config']['files']
def get_model_download(self, model_name):
return self.models[model_name]['config']['download']
def get_dependency_download(self, dependency_name):
return self.dependencies[dependency_name]['config']['download']
def get_available_models(self):
return self.available_models
def get_available_dependencies(self):
return self.available_dependencies
def get_loaded_models(self):
return self.loaded_models
def get_loaded_models_names(self):
return list(self.loaded_models.keys())
def get_loaded_model(self, model_name):
return self.loaded_models[model_name]
def unload_model(self, model_name):
if model_name in self.loaded_models:
del self.loaded_models[model_name]
return True
return False
def unload_all_models(self):
for model in self.loaded_models:
del self.loaded_models[model]
return True
def taint_model(self,model_name):
'''Marks a model as not valid by remiving it from available_models'''
if model_name in self.available_models:
self.available_models.remove(model_name)
self.tainted_models.append(model_name)
def taint_models(self, models):
for model in models:
self.taint_model(model)
def load_model_from_config(self, model_path='', config_path='', map_location="cpu"):
config = OmegaConf.load(config_path)
pl_sd = torch.load(model_path, map_location=map_location)
if "global_step" in pl_sd:
logger.info(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model = model.eval()
del pl_sd, sd, m, u
return model
def load_ckpt(self, model_name='', precision='half', gpu_id=0):
ckpt_path = self.get_model_files(model_name)[0]['path']
config_path = self.get_model_files(model_name)[1]['path']
model = self.load_model_from_config(model_path=ckpt_path, config_path=config_path)
device = torch.device(f"cuda:{gpu_id}")
model = (model if precision=='full' else model.half()).to(device)
torch_gc()
return {'model': model, 'device': device}
def load_realesrgan(self, model_name='', precision='half', gpu_id=0):
RealESRGAN_models = {
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
}
model_path = self.get_model_files(model_name)[0]['path']
device = torch.device(f"cuda:{gpu_id}")
model = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[models[model_name]['name']],
pre_pad=0, half=True if precision == 'half' else False, device=device)
return {'model': model, 'device': device}
def load_gfpgan(self, model_name='', gpu_id=0):
model_path = self.get_model_files(model_name)[0]['path']
device = torch.device(f"cuda:{gpu_id}")
model = GFPGANer(model_path=model_path, upscale=1, arch='clean',
channel_multiplier=2, bg_upsampler=None, device=device)
return {'model': model, 'device': device}
def load_blip(self, model_name='', precision='half', gpu_id=0, blip_image_eval_size=512, vit='base'):
# vit = 'base' or 'large'
model_path = self.get_model_files(model_name)[0]['path']
device = torch.device(f"cuda:{gpu_id}")
model = blip_decoder(pretrained=model_path,
med_config="configs/blip/med_config.json",
image_size=blip_image_eval_size, vit=vit)
model = model.eval()
model = (model if precision=='full' else model.half()).to(device)
return {'model': model, 'device': device}
def load_open_clip(self, model_name='', precision='half', gpu_id=0):
pretrained = self.get_model(model_name)['pretrained_name']
device = torch.device(f"cuda:{gpu_id}")
model, _, preprocesses = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, cache_dir='models/clip')
model = model.eval()
model = (model if precision=='full' else model.half()).to(device)
return {'model': model, 'device': device, 'preprocesses': preprocesses}
def load_clip(self, model_name='', precision='half', gpu_id=0):
device = torch.device(f"cuda:{gpu_id}")
model, preprocesses = clip.load(model_name, device=device, download_root='models/clip')
model = model.eval()
model = (model if precision=='full' else model.half()).to(device)
return {'model': model, 'device': device, 'preprocesses': preprocesses}
def load_model(self, model_name='', precision='half', gpu_id=0):
if model_name not in self.available_models:
return False
if self.models[model_name]['type'] == 'ckpt':
self.loaded_models[model_name] = self.load_ckpt(model_name, precision, gpu_id)
return True
elif self.models[model_name]['type'] == 'realesrgan':
self.loaded_models[model_name] = self.load_realesrgan(model_name, precision, gpu_id)
return True
elif self.models[model_name]['type'] == 'gfpgan':
self.loaded_models[model_name] = self.load_gfpgan(model_name, gpu_id)
return True
elif self.models[model_name]['type'] == 'blip':
self.loaded_models[model_name] = self.load_blip(model_name, precision, gpu_id, 512, 'base')
return True
elif self.models[model_name]['type'] == 'open_clip':
self.loaded_models[model_name] = self.load_open_clip(model_name, precision, gpu_id)
return True
elif self.models[model_name]['type'] == 'clip':
self.loaded_models[model_name] = self.load_clip(model_name, precision, gpu_id)
return True
else:
return False
def validate_model(self, model_name):
files = self.get_model_files(model_name)
all_ok = True
for file_details in files:
if not self.check_file_available(file_details['path']):
return False
if not self.validate_file(file_details):
return False
return True
def validate_file(self, file_details):
if 'md5sum' in file_details:
file_name = file_details['path']
logger.debug(f"Getting md5sum of {file_name}")
with open(file_name, 'rb') as file_to_check:
file_hash = hashlib.md5()
while chunk := file_to_check.read(8192):
file_hash.update(chunk)
if file_details['md5sum'] != file_hash.hexdigest():
return False
return True
def check_file_available(self, file_path):
return os.path.exists(file_path)
def check_available(self, files):
available = True
for file in files:
if not self.check_file_available(file['path']):
available = False
return available
def download_file(self, url, file_path):
# make directory
os.makedirs(os.path.dirname(file_path), exist_ok=True)
pbar_desc = file_path.split('/')[-1]
r = requests.get(url, stream=True)
with open(file_path, 'wb') as f:
with tqdm(
# all optional kwargs
unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
desc=pbar_desc, total=int(r.headers.get('content-length', 0))
) as pbar:
for chunk in r.iter_content(chunk_size=16*1024):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
def download_model(self, model_name):
if model_name in self.available_models:
logger.info(f"{model_name} is already available.")
return True
download = self.get_model_download(model_name)
files = self.get_model_files(model_name)
for i in range(len(download)):
file_path = f"{download[i]['file_path']}/{download[i]['file_name']}" if 'file_path' in download[i] else files[i]['path']
if 'file_url' in download[i]:
download_url = download[i]['file_url']
if 'hf_auth' in download[i]:
username = self.hf_auth['username']
password = self.hf_auth['password']
download_url = download_url.format(username=username, password=password)
if 'file_name' in download[i]:
download_name = download[i]['file_name']
if 'file_path' in download[i]:
download_path = download[i]['file_path']
if 'manual' in download[i]:
logger.warning(f"The model {model_name} requires manual download from {download_url}. Please place it in {download_path}/{download_name} then press ENTER to continue...")
input('')
continue
# TODO: simplify
if "file_content" in download[i]:
file_content = download[i]['file_content']
logger.info(f"writing {file_content} to {file_path}")
# make directory download_path
os.makedirs(download_path, exist_ok=True)
# write file_content to download_path/download_name
with open(os.path.join(download_path, download_name), 'w') as f:
f.write(file_content)
elif 'symlink' in download[i]:
logger.info(f"symlink {file_path} to {download[i]['symlink']}")
symlink = download[i]['symlink']
# make directory symlink
os.makedirs(download_path, exist_ok=True)
# make symlink from download_path/download_name to symlink
os.symlink(symlink, os.path.join(download_path, download_name))
elif 'git' in download[i]:
logger.info(f"git clone {download_url} to {file_path}")
# make directory download_path
os.makedirs(file_path, exist_ok=True)
git.Git(file_path).clone(download_url)
if 'post_process' in download[i]:
for post_process in download[i]['post_process']:
if 'delete' in post_process:
# delete folder post_process['delete']
logger.info(f"delete {post_process['delete']}")
try:
shutil.rmtree(post_process['delete'])
except PermissionError as e:
logger.error(f"[!] Something went wrong while deleting the `{post_process['delete']}`. Please delete it manually.")
logger.error("PermissionError: ", e)
else:
if not self.check_file_available(file_path) or model_name in self.tainted_models:
logger.debug(f'Downloading {download_url} to {file_path}')
self.download_file(download_url, file_path)
if not self.validate_model(model_name):
return False
if model_name in self.tainted_models:
self.tainted_models.remove(model_name)
self.init()
return True
def download_dependency(self, dependency_name):
if dependency_name in self.available_dependencies:
logger.info(f"{dependency_name} is already installed.")
return True
download = self.get_dependency_download(dependency_name)
files = self.get_dependency_files(dependency_name)
for i in range(len(download)):
if "git" in download[i]:
logger.warning("git download not implemented yet")
break
file_path = files[i]['path']
if 'file_url' in download[i]:
download_url = download[i]['file_url']
if 'file_name' in download[i]:
download_name = download[i]['file_name']
if 'file_path' in download[i]:
download_path = download[i]['file_path']
logger.debug(download_name)
if "unzip" in download[i]:
zip_path = f'temp/{download_name}.zip'
# os dirname zip_path
# mkdir temp
os.makedirs("temp", exist_ok=True)
self.download_file(download_url, zip_path)
logger.info(f"unzip {zip_path}")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall('temp/')
# move temp/sd-concepts-library-main/sd-concepts-library to download_path
logger.info(f"move temp/{download_name}-main/{download_name} to {download_path}")
shutil.move(f"temp/{download_name}-main/{download_name}", download_path)
logger.info(f"delete {zip_path}")
os.remove(zip_path)
logger.info(f"delete temp/{download_name}-main/")
shutil.rmtree(f"temp/{download_name}-main")
else:
if not self.check_file_available(file_path):
logger.init(f'{file_path}', status="Downloading")
self.download_file(download_url, file_path)
self.init()
return True
def download_all_models(self):
for model in self.get_filtered_model_names(download_all = True):
if not self.check_model_available(model):
logger.init(f"{model}", status="Downloading")
self.download_model(model)
else:
logger.info(f"{model} is already downloaded.")
return True
def download_all_dependencies(self):
for dependency in self.dependencies:
if not self.check_dependency_available(dependency):
logger.init(f"{dependency}",status="Downloading")
self.download_dependency(dependency)
else:
logger.info(f"{dependency} is already installed.")
return True
def download_all(self):
self.download_all_dependencies()
self.download_all_models()
return True
def check_all_available(self):
for model in self.models:
if not self.check_available(self.get_model_files(model)):
return False
for dependency in self.dependencies:
if not self.check_available(self.get_dependency_files(dependency)):
return False
return True
def check_model_available(self, model_name):
if model_name not in self.models:
return False
return self.check_available(self.get_model_files(model_name))
def check_dependency_available(self, dependency_name):
if dependency_name not in self.dependencies:
return False
return self.check_available(self.get_dependency_files(dependency_name))
def check_all_available(self):
for model in self.models:
if not self.check_model_available(model):
return False
for dependency in self.dependencies:
if not self.check_dependency_available(dependency):
return False
return True

View File

@ -1,48 +0,0 @@
# Class realesrgan
# Inputs:
# - model
# - device
# - output_dir
# - output_ext
# outupts:
# - output_images
import PIL
from torchvision import transforms
import numpy as np
import os
import cv2
from nataili.util.save_sample import save_sample
class realesrgan:
def __init__(self, model, device, output_dir, output_ext='jpg'):
self.model = model
self.device = device
self.output_dir = output_dir
self.output_ext = output_ext
self.output_images = []
def generate(self, input_image):
# load image
img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
img_mode = 'RGBA'
else:
img_mode = None
# upscale
output, _ = self.model.enhance(img)
if img_mode == 'RGBA': # RGBA images should be saved in png format
self.output_ext = 'png'
esrgan_sample = output[:,:,::-1]
esrgan_image = PIL.Image.fromarray(esrgan_sample)
# append model name to output image name
filename = os.path.basename(input_image)
filename = os.path.splitext(filename)[0]
filename = f'{filename}_esrgan'
filename_with_ext = f'{filename}.{self.output_ext}'
output_image = os.path.join(self.output_dir, filename_with_ext)
save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
self.output_images.append(output_image)
return

View File

@ -1,48 +0,0 @@
# Class realesrgan
# Inputs:
# - model
# - device
# - output_dir
# - output_ext
# outupts:
# - output_images
import PIL
from torchvision import transforms
import numpy as np
import os
import cv2
from nataili.util.save_sample import save_sample
class realesrgan:
def __init__(self, model, device, output_dir, output_ext='jpg'):
self.model = model
self.device = device
self.output_dir = output_dir
self.output_ext = output_ext
self.output_images = []
def generate(self, input_image):
# load image
img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
if len(img.shape) == 3 and img.shape[2] == 4:
img_mode = 'RGBA'
else:
img_mode = None
# upscale
output, _ = self.model.enhance(img)
if img_mode == 'RGBA': # RGBA images should be saved in png format
self.output_ext = 'png'
esrgan_sample = output[:,:,::-1]
esrgan_image = PIL.Image.fromarray(esrgan_sample)
# append model name to output image name
filename = os.path.basename(input_image)
filename = os.path.splitext(filename)[0]
filename = f'{filename}_esrgan'
filename_with_ext = f'{filename}.{self.output_ext}'
output_image = os.path.join(self.output_dir, filename_with_ext)
save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
self.output_images.append(output_image)
return

View File

@ -1 +0,0 @@
from nataili.util.logger import logger,set_logger_verbosity, quiesce_logger, test_logger

View File

@ -1,16 +0,0 @@
import gc
import torch
import threading
import pynvml
import time
with torch.no_grad():
def torch_gc():
for _ in range(2):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()

View File

@ -1,18 +0,0 @@
def check_prompt_length(model, prompt, comments):
"""this function tests if prompt is too long, and if so, adds a message to comments"""
tokenizer = model.cond_stage_model.tokenizer
max_length = model.cond_stage_model.max_length
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
ovf = info['overflowing_tokens'][0]
overflowing_count = ovf.shape[0]
if overflowing_count == 0:
return
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
del tokenizer

View File

@ -1,22 +0,0 @@
from pathlib import Path
def get_next_sequence_number(path, prefix=''):
"""
Determines and returns the next sequence number to use when saving an
image in the specified directory.
If a prefix is given, only consider files whose names start with that
prefix, and strip the prefix from filenames before extracting their
sequence number.
The sequence starts at 0.
"""
result = -1
for p in Path(path).iterdir():
if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
tmp = p.name[len(prefix):]
try:
result = max(int(tmp.split('-')[0]), result)
except ValueError:
pass
return result + 1

View File

@ -1,21 +0,0 @@
import math
import PIL
def image_grid(imgs, n_rows=None):
if n_rows is not None:
rows = n_rows
else:
rows = math.sqrt(len(imgs))
rows = round(rows)
cols = math.ceil(len(imgs) / rows)
w, h = imgs[0].size
grid = PIL.Image.new('RGB', size=(cols * w, rows * h), color='black')
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid

View File

@ -1,40 +0,0 @@
import os
import torch
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
# separate token and the embeds
if learned_embeds_path.endswith('.pt'):
# old format
# token = * so replace with file directory name when converting
trained_token = os.path.basename(learned_embeds_path)
params_dict = {
trained_token: torch.tensor(list(loaded_learned_embeds['string_to_param'].items())[0][1])
}
learned_embeds_path = os.path.splitext(learned_embeds_path)[0] + '.bin'
torch.save(params_dict, learned_embeds_path)
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]
elif learned_embeds_path.endswith('.bin'):
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]
embeds = loaded_learned_embeds[trained_token]
# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
embeds.to(dtype)
# add the token in tokenizer
token = token if token is not None else trained_token
num_added_tokens = tokenizer.add_tokens(token)
# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))
# get the id for the token and assign the embeds
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
return token

View File

@ -1,102 +0,0 @@
import sys
from functools import partialmethod
from loguru import logger
STDOUT_LEVELS = ["GENERATION", "PROMPT"]
INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
MESSAGE_LEVELS = ["MESSAGE"]
# By default we're at error level or higher
verbosity = 20
quiet = 0
def set_logger_verbosity(count):
global verbosity
# The count comes reversed. So count = 0 means minimum verbosity
# While count 5 means maximum verbosity
# So the more count we have, the lowe we drop the versbosity maximum
verbosity = 20 - (count * 10)
def quiesce_logger(count):
global quiet
# The bigger the count, the more silent we want our logger
quiet = count * 10
def is_stdout_log(record):
if record["level"].name not in STDOUT_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def is_init_log(record):
if record["level"].name not in INIT_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def is_msg_log(record):
if record["level"].name not in MESSAGE_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def is_stderr_log(record):
if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
return(False)
if record["level"].no < verbosity + quiet:
return(False)
return(True)
def test_logger():
logger.generation("This is a generation message\nIt is typically multiline\nThee Lines".encode("unicode_escape").decode("utf-8"))
logger.prompt("This is a prompt message")
logger.debug("Debug Message")
logger.info("Info Message")
logger.warning("Info Warning")
logger.error("Error Message")
logger.critical("Critical Message")
logger.init("This is an init message", status="Starting")
logger.init_ok("This is an init message", status="OK")
logger.init_warn("This is an init message", status="Warning")
logger.init_err("This is an init message", status="Error")
logger.message("This is user message")
sys.exit()
logfmt = "<level>{level: <10}</level> | <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <green>{name}</green>:<green>{function}</green>:<green>{line}</green> - <level>{message}</level>"
genfmt = "<level>{level: <10}</level> @ <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{message}</level>"
initfmt = "<magenta>INIT </magenta> | <level>{extra[status]: <11}</level> | <magenta>{message}</magenta>"
msgfmt = "<level>{level: <10}</level> | <level>{message}</level>"
try:
logger.level("GENERATION", no=24, color="<cyan>")
logger.level("PROMPT", no=23, color="<yellow>")
logger.level("INIT", no=31, color="<white>")
logger.level("INIT_OK", no=31, color="<green>")
logger.level("INIT_WARN", no=31, color="<yellow>")
logger.level("INIT_ERR", no=31, color="<red>")
# Messages contain important information without which this application might not be able to be used
# As such, they have the highest priority
logger.level("MESSAGE", no=61, color="<green>")
except TypeError:
pass
logger.__class__.generation = partialmethod(logger.__class__.log, "GENERATION")
logger.__class__.prompt = partialmethod(logger.__class__.log, "PROMPT")
logger.__class__.init = partialmethod(logger.__class__.log, "INIT")
logger.__class__.init_ok = partialmethod(logger.__class__.log, "INIT_OK")
logger.__class__.init_warn = partialmethod(logger.__class__.log, "INIT_WARN")
logger.__class__.init_err = partialmethod(logger.__class__.log, "INIT_ERR")
logger.__class__.message = partialmethod(logger.__class__.log, "MESSAGE")
config = {
"handlers": [
{"sink": sys.stderr, "format": logfmt, "colorize":True, "filter": is_stderr_log},
{"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize":True, "filter": is_stdout_log},
{"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize":True, "filter": is_init_log},
{"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize":True, "filter": is_msg_log}
],
}
logger.configure(**config)

View File

@ -1,20 +0,0 @@
import os
def save_sample(image, filename, sample_path, extension='png', jpg_quality=95, webp_quality=95, webp_lossless=True, png_compression=9):
path = os.path.join(sample_path, filename + '.' + extension)
if os.path.exists(path):
return False
if not os.path.exists(sample_path):
os.makedirs(sample_path)
if extension == 'png':
image.save(path, format='PNG', compress_level=png_compression)
elif extension == 'jpg':
image.save(path, quality=jpg_quality, optimize=True)
elif extension == 'webp':
image.save(path, quality=webp_quality, lossless=webp_lossless)
else:
return False
if os.path.exists(path):
return True
else:
return False

View File

@ -1,22 +0,0 @@
import random
def seed_to_int(s):
if type(s) is int:
return s
if s is None or s == '':
return random.randint(0, 2**32 - 1)
if type(s) is list:
seed_list = []
for seed in s:
if seed is None or seed == '':
seed_list.append(random.randint(0, 2**32 - 1))
else:
seed_list = s
return seed_list
n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
while n >= 2**32:
n = n >> 32
return n

View File

@ -75,15 +75,9 @@ from pathlib import Path
from huggingface_hub import hf_hub_download
#import librosa
#from logger import logger, set_logger_verbosity, quiesce_logger
from logger import logger, set_logger_verbosity, quiesce_logger
#from loguru import logger
#from nataili.inference.compvis.img2img import img2img
#from nataili.model_manager import ModelManager
#from nataili.inference.compvis.txt2img import txt2img
from nataili.util.cache import torch_gc
from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
try:
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet