mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-14 22:13:41 +03:00
Added nataili as part of the repo under the scripts folder.
This commit is contained in:
parent
40d9d3ea93
commit
0110973b68
0
scripts/nataili/__init__.py
Normal file
0
scripts/nataili/__init__.py
Normal file
0
scripts/nataili/inference/__init__.py
Normal file
0
scripts/nataili/inference/__init__.py
Normal file
0
scripts/nataili/inference/compvis/__init__.py
Normal file
0
scripts/nataili/inference/compvis/__init__.py
Normal file
551
scripts/nataili/inference/compvis/img2img.py
Normal file
551
scripts/nataili/inference/compvis/img2img.py
Normal file
@ -0,0 +1,551 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import k_diffusion as K
|
||||
import tqdm
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import skimage
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.kdiffusion import CFGMaskedDenoiser, KDiffusionSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from nataili.util.cache import torch_gc
|
||||
from nataili.util.check_prompt_length import check_prompt_length
|
||||
from nataili.util.get_next_sequence_number import get_next_sequence_number
|
||||
from nataili.util.image_grid import image_grid
|
||||
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
|
||||
from nataili.util.save_sample import save_sample
|
||||
from nataili.util.seed_to_int import seed_to_int
|
||||
from slugify import slugify
|
||||
import PIL
|
||||
|
||||
|
||||
class img2img:
|
||||
def __init__(self, model, device, output_dir, save_extension='jpg',
|
||||
output_file_path=False, load_concepts=False, concepts_dir=None,
|
||||
verify_input=True, auto_cast=True):
|
||||
self.model = model
|
||||
self.output_dir = output_dir
|
||||
self.output_file_path = output_file_path
|
||||
self.save_extension = save_extension
|
||||
self.load_concepts = load_concepts
|
||||
self.concepts_dir = concepts_dir
|
||||
self.verify_input = verify_input
|
||||
self.auto_cast = auto_cast
|
||||
self.device = device
|
||||
self.comments = []
|
||||
self.output_images = []
|
||||
self.info = ''
|
||||
self.stats = ''
|
||||
self.images = []
|
||||
|
||||
def create_random_tensors(self, shape, seeds):
|
||||
xs = []
|
||||
for seed in seeds:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# randn results depend on device; gpu and cpu get different results for same seed;
|
||||
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
||||
# but the original script had it like this so i do not dare change it for now because
|
||||
# it will break everyone's seeds.
|
||||
xs.append(torch.randn(shape, device=self.device))
|
||||
x = torch.stack(xs)
|
||||
return x
|
||||
|
||||
def process_prompt_tokens(self, prompt_tokens):
|
||||
# compviz codebase
|
||||
tokenizer = self.model.cond_stage_model.tokenizer
|
||||
text_encoder = self.model.cond_stage_model.transformer
|
||||
|
||||
# diffusers codebase
|
||||
#tokenizer = pipe.tokenizer
|
||||
#text_encoder = pipe.text_encoder
|
||||
|
||||
ext = ('.pt', '.bin')
|
||||
for token_name in prompt_tokens:
|
||||
embedding_path = os.path.join(self.concepts_dir, token_name)
|
||||
if os.path.exists(embedding_path):
|
||||
for files in os.listdir(embedding_path):
|
||||
if files.endswith(ext):
|
||||
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
|
||||
else:
|
||||
print(f"Concept {token_name} not found in {self.concepts_dir}")
|
||||
del tokenizer, text_encoder
|
||||
return
|
||||
del tokenizer, text_encoder
|
||||
|
||||
def resize_image(self, resize_mode, im, width, height):
|
||||
LANCZOS = (PIL.Image.Resampling.LANCZOS if hasattr(PIL.Image, 'Resampling') else PIL.Image.LANCZOS)
|
||||
if resize_mode == "resize":
|
||||
res = im.resize((width, height), resample=LANCZOS)
|
||||
elif resize_mode == "crop":
|
||||
ratio = width / height
|
||||
src_ratio = im.width / im.height
|
||||
|
||||
src_w = width if ratio > src_ratio else im.width * height // im.height
|
||||
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
||||
|
||||
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
||||
res = PIL.Image.new("RGBA", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
else:
|
||||
ratio = width / height
|
||||
src_ratio = im.width / im.height
|
||||
|
||||
src_w = width if ratio < src_ratio else im.width * height // im.height
|
||||
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
||||
|
||||
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
||||
res = PIL.Image.new("RGBA", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
||||
|
||||
return res
|
||||
|
||||
#
|
||||
# helper fft routines that keep ortho normalization and auto-shift before and after fft
|
||||
def _fft2(self, data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:,:,c]
|
||||
out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho")
|
||||
out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c])
|
||||
else: # one channel
|
||||
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
|
||||
out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho")
|
||||
out_fft[:,:] = np.fft.ifftshift(out_fft[:,:])
|
||||
|
||||
return out_fft
|
||||
|
||||
def _ifft2(self, data):
|
||||
if data.ndim > 2: # has channels
|
||||
out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
|
||||
for c in range(data.shape[2]):
|
||||
c_data = data[:,:,c]
|
||||
out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho")
|
||||
out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c])
|
||||
else: # one channel
|
||||
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
|
||||
out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho")
|
||||
out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:])
|
||||
|
||||
return out_ifft
|
||||
|
||||
def _get_gaussian_window(self, width, height, std=3.14, mode=0):
|
||||
|
||||
window_scale_x = float(width / min(width, height))
|
||||
window_scale_y = float(height / min(width, height))
|
||||
|
||||
window = np.zeros((width, height))
|
||||
x = (np.arange(width) / width * 2. - 1.) * window_scale_x
|
||||
for y in range(height):
|
||||
fy = (y / height * 2. - 1.) * window_scale_y
|
||||
if mode == 0:
|
||||
window[:, y] = np.exp(-(x**2+fy**2) * std)
|
||||
else:
|
||||
window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian
|
||||
|
||||
return window
|
||||
|
||||
def _get_masked_window_rgb(self, np_mask_grey, hardness=1.):
|
||||
np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
|
||||
if hardness != 1.:
|
||||
hardened = np_mask_grey[:] ** hardness
|
||||
else:
|
||||
hardened = np_mask_grey[:]
|
||||
for c in range(3):
|
||||
np_mask_rgb[:,:,c] = hardened[:]
|
||||
return np_mask_rgb
|
||||
|
||||
def get_matched_noise(self, _np_src_image, np_mask_rgb, noise_q, color_variation):
|
||||
"""
|
||||
Explanation:
|
||||
Getting good results in/out-painting with stable diffusion can be challenging.
|
||||
Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data
|
||||
in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well.
|
||||
Provided here is my take on a potential solution to this problem.
|
||||
|
||||
By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src.
|
||||
Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence
|
||||
by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can
|
||||
be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color
|
||||
or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently)
|
||||
This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area.
|
||||
The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle.
|
||||
|
||||
Carefully managing color and brightness with histogram matching is also essential to achieving good coherence.
|
||||
|
||||
noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.)
|
||||
color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01)
|
||||
This code is provided as is under the Unlicense (https://unlicense.org/)
|
||||
Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz].
|
||||
|
||||
Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/)
|
||||
This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot)
|
||||
|
||||
"""
|
||||
|
||||
global DEBUG_MODE
|
||||
global TMP_ROOT_PATH
|
||||
|
||||
width = _np_src_image.shape[0]
|
||||
height = _np_src_image.shape[1]
|
||||
num_channels = _np_src_image.shape[2]
|
||||
|
||||
np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
|
||||
np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.)
|
||||
np_src_grey = (np.sum(np_src_image, axis=2)/3.)
|
||||
all_mask = np.ones((width, height), dtype=bool)
|
||||
img_mask = np_mask_grey > 1e-6
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
windowed_image = _np_src_image * (1.-self._get_masked_window_rgb(np_mask_grey))
|
||||
windowed_image /= np.max(windowed_image)
|
||||
windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
|
||||
#windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area
|
||||
#_save_debug_img(windowed_image, "windowed_src_img")
|
||||
|
||||
src_fft = self._fft2(windowed_image) # get feature statistics from masked src img
|
||||
src_dist = np.absolute(src_fft)
|
||||
src_phase = src_fft / src_dist
|
||||
#_save_debug_img(src_dist, "windowed_src_dist")
|
||||
|
||||
noise_window = self._get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
|
||||
noise_rgb = np.random.random_sample((width, height, num_channels))
|
||||
noise_grey = (np.sum(noise_rgb, axis=2)/3.)
|
||||
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
|
||||
for c in range(num_channels):
|
||||
noise_rgb[:,:,c] += (1. - color_variation) * noise_grey
|
||||
|
||||
noise_fft = self._fft2(noise_rgb)
|
||||
for c in range(num_channels):
|
||||
noise_fft[:,:,c] *= noise_window
|
||||
noise_rgb = np.real(self._ifft2(noise_fft))
|
||||
shaped_noise_fft = self._fft2(noise_rgb)
|
||||
shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
|
||||
|
||||
brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now
|
||||
contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
|
||||
|
||||
# scikit-image is used for histogram matching, very convenient!
|
||||
shaped_noise = np.real(self._ifft2(shaped_noise_fft))
|
||||
shaped_noise -= np.min(shaped_noise)
|
||||
shaped_noise /= np.max(shaped_noise)
|
||||
shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1)
|
||||
shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
|
||||
#_save_debug_img(shaped_noise, "shaped_noise")
|
||||
|
||||
matched_noise = np.zeros((width, height, num_channels))
|
||||
matched_noise = shaped_noise[:]
|
||||
#matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1)
|
||||
#matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb
|
||||
|
||||
#_save_debug_img(matched_noise, "matched_noise")
|
||||
|
||||
"""
|
||||
todo:
|
||||
color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled
|
||||
"""
|
||||
|
||||
return np.clip(matched_noise, 0., 1.)
|
||||
|
||||
def find_noise_for_image(self, model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None):
|
||||
image = np.array(init_image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
image = 2. * image - 1.
|
||||
image = image.to(device)
|
||||
x = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||
|
||||
uncond = model.get_learned_conditioning([''])
|
||||
cond = model.get_learned_conditioning([prompt])
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
dnw = K.external.CompVisDenoiser(model)
|
||||
sigmas = dnw.get_sigmas(steps).flip(0)
|
||||
|
||||
if verbose:
|
||||
print(sigmas)
|
||||
|
||||
for i in tqdm.trange(1, len(sigmas)):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
|
||||
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
|
||||
|
||||
if i == 1:
|
||||
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
|
||||
else:
|
||||
t = dnw.sigma_to_t(sigma_in)
|
||||
|
||||
eps = model.apply_model(x_in * c_in, t, cond=cond_in)
|
||||
denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
|
||||
|
||||
denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale
|
||||
|
||||
if i == 1:
|
||||
d = (x - denoised) / (2 * sigmas[i])
|
||||
else:
|
||||
d = (x - denoised) / sigmas[i - 1]
|
||||
|
||||
dt = sigmas[i] - sigmas[i - 1]
|
||||
x = x + d * dt
|
||||
|
||||
return x / sigmas[-1]
|
||||
|
||||
def generate(self, prompt: str, init_img=None, init_mask=None, mask_mode='mask', resize_mode='resize', noise_mode='seed',
|
||||
denoising_strength:float=0.8, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
|
||||
height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
|
||||
seed = seed_to_int(seed)
|
||||
image_dict = {
|
||||
"seed": seed
|
||||
}
|
||||
# Init image is assumed to be a PIL image
|
||||
init_img = self.resize_image('resize', init_img, width, height)
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(self.model)
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(self.model)
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(self.model,'dpm_2')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(self.model,'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(self.model,'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(self.model,'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(self.model,'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
torch_gc()
|
||||
def process_init_mask(init_mask: PIL.Image):
|
||||
if init_mask.mode == "RGBA":
|
||||
init_mask = init_mask.convert('RGBA')
|
||||
background = PIL.Image.new('RGBA', init_mask.size, (0, 0, 0))
|
||||
init_mask = PIL.Image.alpha_composite(background, init_mask)
|
||||
init_mask = init_mask.convert('RGB')
|
||||
return init_mask
|
||||
|
||||
if mask_mode == "mask":
|
||||
if init_mask:
|
||||
init_mask = process_init_mask(init_mask)
|
||||
elif mask_mode == "invert":
|
||||
if init_mask:
|
||||
init_mask = process_init_mask(init_mask)
|
||||
init_mask = PIL.ImageOps.invert(init_mask)
|
||||
elif mask_mode == "alpha":
|
||||
init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
init_mask = init_img_transparency
|
||||
init_mask = init_mask.convert("RGB")
|
||||
init_mask = self.resize_image(resize_mode, init_mask, width, height)
|
||||
init_mask = init_mask.convert("RGB")
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
|
||||
if init_mask is not None and (noise_mode == "matched" or noise_mode == "find_and_matched") and init_img is not None:
|
||||
noise_q = 0.99
|
||||
color_variation = 0.0
|
||||
mask_blend_factor = 1.0
|
||||
|
||||
np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
|
||||
np_mask_rgb = 1. - (np.asarray(PIL.ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
|
||||
np_mask_rgb -= np.min(np_mask_rgb)
|
||||
np_mask_rgb /= np.max(np_mask_rgb)
|
||||
np_mask_rgb = 1. - np_mask_rgb
|
||||
np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
|
||||
blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
#np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
|
||||
#np_mask_rgb = np_mask_rgb + blurred
|
||||
np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
|
||||
np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
|
||||
|
||||
noise_rgb = self.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
|
||||
blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
|
||||
noised = noise_rgb[:]
|
||||
blend_mask_rgb **= (2.)
|
||||
noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
|
||||
|
||||
np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
all_mask = np.ones((height, width), dtype=bool)
|
||||
noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
|
||||
|
||||
init_img = PIL.Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
|
||||
|
||||
def init():
|
||||
image = init_img.convert('RGB')
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask_channel = None
|
||||
if init_mask:
|
||||
alpha = self.resize_image(resize_mode, init_mask, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[-1]
|
||||
|
||||
mask = None
|
||||
if mask_channel is not None:
|
||||
mask = np.array(mask_channel).astype(np.float32) / 255.0
|
||||
mask = (1 - mask)
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3)
|
||||
mask = torch.from_numpy(mask).to(self.model.device)
|
||||
|
||||
init_image = 2. * image - 1.
|
||||
init_image = init_image.to(self.model.device)
|
||||
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
return init_latent, mask,
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
t_enc_steps = t_enc
|
||||
obliterate = False
|
||||
if ddim_steps == t_enc_steps:
|
||||
t_enc_steps = t_enc_steps - 1
|
||||
obliterate = True
|
||||
|
||||
if sampler_name != 'DDIM':
|
||||
x0, z_mask = init_data
|
||||
|
||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
||||
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
|
||||
|
||||
xi = x0 + noise
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=xi.device)
|
||||
xi = (z_mask * noise) + ((1-z_mask) * xi)
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
|
||||
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
|
||||
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
|
||||
else:
|
||||
|
||||
x0, z_mask = init_data
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(self.model.device))
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=z_enc.device)
|
||||
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
|
||||
|
||||
# decode it
|
||||
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
z_mask=z_mask, x0=x0)
|
||||
return samples_ddim
|
||||
|
||||
torch_gc()
|
||||
|
||||
if self.load_concepts and self.concepts_dir is not None:
|
||||
prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
|
||||
if prompt_tokens:
|
||||
self.process_prompt_tokens(prompt_tokens)
|
||||
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
sample_path = os.path.join(self.output_dir, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
|
||||
if self.verify_input:
|
||||
try:
|
||||
check_prompt_length(self.model, prompt, self.comments)
|
||||
except:
|
||||
import traceback
|
||||
print("Error verifying input:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
all_prompts = batch_size * n_iter * [prompt]
|
||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||
|
||||
precision_scope = torch.autocast if self.auto_cast else nullcontext
|
||||
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
for n in range(n_iter):
|
||||
print(f"Iteration: {n+1}/{n_iter}")
|
||||
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
|
||||
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
|
||||
|
||||
uc = self.model.get_learned_conditioning(len(prompts) * [''])
|
||||
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
c = self.model.get_learned_conditioning(prompts)
|
||||
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
shape = [opt_C, height // opt_f, width // opt_f]
|
||||
|
||||
x = self.create_random_tensors(shape, seeds=seeds)
|
||||
init_data = init()
|
||||
samples_ddim = sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
|
||||
|
||||
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
sanitized_prompt = slugify(prompts[i])
|
||||
full_path = os.path.join(os.getcwd(), sample_path)
|
||||
sample_path_i = sample_path
|
||||
base_count = get_next_sequence_number(sample_path_i)
|
||||
filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
|
||||
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
image = PIL.Image.fromarray(x_sample)
|
||||
image_dict['image'] = image
|
||||
self.images.append(image_dict)
|
||||
|
||||
if save_individual_images:
|
||||
path = os.path.join(sample_path, filename + '.' + self.save_extension)
|
||||
success = save_sample(image, filename, sample_path_i, self.save_extension)
|
||||
if success:
|
||||
if self.output_file_path:
|
||||
self.output_images.append(path)
|
||||
else:
|
||||
self.output_images.append(image)
|
||||
else:
|
||||
return
|
||||
|
||||
self.info = f"""
|
||||
{prompt}
|
||||
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
|
||||
""".strip()
|
||||
self.stats = f'''
|
||||
'''
|
||||
|
||||
for comment in self.comments:
|
||||
self.info += "\n\n" + comment
|
||||
|
||||
torch_gc()
|
||||
|
||||
del sampler
|
||||
|
||||
return
|
201
scripts/nataili/inference/compvis/txt2img.py
Normal file
201
scripts/nataili/inference/compvis/txt2img.py
Normal file
@ -0,0 +1,201 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.kdiffusion import KDiffusionSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from nataili.util.cache import torch_gc
|
||||
from nataili.util.check_prompt_length import check_prompt_length
|
||||
from nataili.util.get_next_sequence_number import get_next_sequence_number
|
||||
from nataili.util.image_grid import image_grid
|
||||
from nataili.util.load_learned_embed_in_clip import load_learned_embed_in_clip
|
||||
from nataili.util.save_sample import save_sample
|
||||
from nataili.util.seed_to_int import seed_to_int
|
||||
from slugify import slugify
|
||||
|
||||
|
||||
class txt2img:
|
||||
def __init__(self, model, device, output_dir, save_extension='jpg',
|
||||
output_file_path=False, load_concepts=False, concepts_dir=None,
|
||||
verify_input=True, auto_cast=True):
|
||||
self.model = model
|
||||
self.output_dir = output_dir
|
||||
self.output_file_path = output_file_path
|
||||
self.save_extension = save_extension
|
||||
self.load_concepts = load_concepts
|
||||
self.concepts_dir = concepts_dir
|
||||
self.verify_input = verify_input
|
||||
self.auto_cast = auto_cast
|
||||
self.device = device
|
||||
self.comments = []
|
||||
self.output_images = []
|
||||
self.info = ''
|
||||
self.stats = ''
|
||||
self.images = []
|
||||
|
||||
def create_random_tensors(self, shape, seeds):
|
||||
xs = []
|
||||
for seed in seeds:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# randn results depend on device; gpu and cpu get different results for same seed;
|
||||
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
||||
# but the original script had it like this so i do not dare change it for now because
|
||||
# it will break everyone's seeds.
|
||||
xs.append(torch.randn(shape, device=self.device))
|
||||
x = torch.stack(xs)
|
||||
return x
|
||||
|
||||
def process_prompt_tokens(self, prompt_tokens):
|
||||
# compviz codebase
|
||||
tokenizer = self.model.cond_stage_model.tokenizer
|
||||
text_encoder = self.model.cond_stage_model.transformer
|
||||
|
||||
# diffusers codebase
|
||||
#tokenizer = pipe.tokenizer
|
||||
#text_encoder = pipe.text_encoder
|
||||
|
||||
ext = ('.pt', '.bin')
|
||||
for token_name in prompt_tokens:
|
||||
embedding_path = os.path.join(self.concepts_dir, token_name)
|
||||
if os.path.exists(embedding_path):
|
||||
for files in os.listdir(embedding_path):
|
||||
if files.endswith(ext):
|
||||
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
|
||||
else:
|
||||
print(f"Concept {token_name} not found in {self.concepts_dir}")
|
||||
del tokenizer, text_encoder
|
||||
return
|
||||
del tokenizer, text_encoder
|
||||
|
||||
def generate(self, prompt: str, ddim_steps=50, sampler_name='k_lms', n_iter=1, batch_size=1, cfg_scale=7.5, seed=None,
|
||||
height=512, width=512, save_individual_images: bool = True, save_grid: bool = True, ddim_eta:float = 0.0):
|
||||
seed = seed_to_int(seed)
|
||||
|
||||
image_dict = {
|
||||
"seed": seed
|
||||
}
|
||||
negprompt = ''
|
||||
if '###' in prompt:
|
||||
prompt, negprompt = prompt.split('###', 1)
|
||||
prompt = prompt.strip()
|
||||
negprompt = negprompt.strip()
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(self.model)
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(self.model)
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(self.model,'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(self.model,'dpm_2')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(self.model,'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(self.model,'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(self.model,'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(self.model,'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning, x_T=x)
|
||||
return samples_ddim
|
||||
|
||||
torch_gc()
|
||||
|
||||
if self.load_concepts and self.concepts_dir is not None:
|
||||
prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
|
||||
if prompt_tokens:
|
||||
self.process_prompt_tokens(prompt_tokens)
|
||||
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
sample_path = os.path.join(self.output_dir, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
|
||||
if self.verify_input:
|
||||
try:
|
||||
check_prompt_length(self.model, prompt, self.comments)
|
||||
except:
|
||||
import traceback
|
||||
print("Error verifying input:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
all_prompts = batch_size * n_iter * [prompt]
|
||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||
|
||||
precision_scope = torch.autocast if self.auto_cast else nullcontext
|
||||
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
for n in range(n_iter):
|
||||
print(f"Iteration: {n+1}/{n_iter}")
|
||||
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
|
||||
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
|
||||
|
||||
uc = self.model.get_learned_conditioning(len(prompts) * [negprompt])
|
||||
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
c = self.model.get_learned_conditioning(prompts)
|
||||
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
shape = [opt_C, height // opt_f, width // opt_f]
|
||||
|
||||
x = self.create_random_tensors(shape, seeds=seeds)
|
||||
|
||||
samples_ddim = sample(init_data=None, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
|
||||
|
||||
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
sanitized_prompt = slugify(prompts[i])
|
||||
full_path = os.path.join(os.getcwd(), sample_path)
|
||||
sample_path_i = sample_path
|
||||
base_count = get_next_sequence_number(sample_path_i)
|
||||
filename = f"{base_count:05}-{ddim_steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:200-len(full_path)]
|
||||
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
image = PIL.Image.fromarray(x_sample)
|
||||
image_dict['image'] = image
|
||||
self.images.append(image_dict)
|
||||
|
||||
if save_individual_images:
|
||||
path = os.path.join(sample_path, filename + '.' + self.save_extension)
|
||||
success = save_sample(image, filename, sample_path_i, self.save_extension)
|
||||
if success:
|
||||
if self.output_file_path:
|
||||
self.output_images.append(path)
|
||||
else:
|
||||
self.output_images.append(image)
|
||||
else:
|
||||
return
|
||||
|
||||
self.info = f"""
|
||||
{prompt}
|
||||
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}
|
||||
""".strip()
|
||||
self.stats = f'''
|
||||
'''
|
||||
|
||||
for comment in self.comments:
|
||||
self.info += "\n\n" + comment
|
||||
|
||||
torch_gc()
|
||||
|
||||
del sampler
|
||||
|
||||
return
|
0
scripts/nataili/inference/diffusers/__init__.py
Normal file
0
scripts/nataili/inference/diffusers/__init__.py
Normal file
458
scripts/nataili/model_manager.py
Normal file
458
scripts/nataili/model_manager.py
Normal file
@ -0,0 +1,458 @@
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import zipfile
|
||||
import requests
|
||||
import git
|
||||
import torch
|
||||
import hashlib
|
||||
from ldm.util import instantiate_from_config
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import logging
|
||||
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from gfpgan import GFPGANer
|
||||
from realesrgan import RealESRGANer
|
||||
from ldm.models.blip import blip_decoder
|
||||
from tqdm import tqdm
|
||||
import open_clip
|
||||
import clip
|
||||
|
||||
from nataili.util.cache import torch_gc
|
||||
from nataili.util import logger
|
||||
|
||||
logging.set_verbosity_error()
|
||||
|
||||
models = json.load(open('./db.json'))
|
||||
dependencies = json.load(open('./db_dep.json'))
|
||||
remote_models = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
|
||||
remote_dependencies = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db_dep.json"
|
||||
|
||||
class ModelManager():
|
||||
def __init__(self, hf_auth=None, download=True):
|
||||
if download:
|
||||
try:
|
||||
logger.init("Model Reference", status="Downloading")
|
||||
r = requests.get(remote_models)
|
||||
self.models = r.json()
|
||||
r = requests.get(remote_dependencies)
|
||||
self.dependencies = json.load(open('./db_dep.json'))
|
||||
logger.init_ok("Model Reference", status="OK")
|
||||
except:
|
||||
logger.init_err("Model Reference", status="Download Error")
|
||||
self.models = json.load(open('./db.json'))
|
||||
self.dependencies = json.load(open('./db_dep.json'))
|
||||
logger.init_warn("Model Reference", status="Local")
|
||||
self.available_models = []
|
||||
self.tainted_models = []
|
||||
self.available_dependencies = []
|
||||
self.loaded_models = {}
|
||||
self.hf_auth = None
|
||||
self.set_authentication(hf_auth)
|
||||
|
||||
def init(self):
|
||||
dependencies_available = []
|
||||
for dependency in self.dependencies:
|
||||
if self.check_available(self.get_dependency_files(dependency)):
|
||||
dependencies_available.append(dependency)
|
||||
self.available_dependencies = dependencies_available
|
||||
|
||||
models_available = []
|
||||
for model in self.models:
|
||||
if self.check_available(self.get_model_files(model)):
|
||||
models_available.append(model)
|
||||
self.available_models = models_available
|
||||
|
||||
if self.hf_auth is not None:
|
||||
if 'username' not in self.hf_auth and 'password' not in self.hf_auth:
|
||||
raise ValueError('hf_auth must contain username and password')
|
||||
else:
|
||||
if self.hf_auth['username'] == '' or self.hf_auth['password'] == '':
|
||||
raise ValueError('hf_auth must contain username and password')
|
||||
return True
|
||||
|
||||
def set_authentication(self, hf_auth=None):
|
||||
# We do not let No authentication override previously set auth
|
||||
if not hf_auth and self.hf_auth:
|
||||
return
|
||||
self.hf_auth = hf_auth
|
||||
|
||||
def get_model(self, model_name):
|
||||
return self.models.get(model_name)
|
||||
|
||||
def get_filtered_models(self, **kwargs):
|
||||
'''Get all model names.
|
||||
Can filter based on metadata of the model reference db
|
||||
'''
|
||||
filtered_models = self.models
|
||||
for keyword in kwargs:
|
||||
iterating_models = filtered_models.copy()
|
||||
filtered_models = {}
|
||||
for model in iterating_models:
|
||||
# logger.debug([keyword,iterating_models[model].get(keyword),kwargs[keyword]])
|
||||
if iterating_models[model].get(keyword) == kwargs[keyword]:
|
||||
filtered_models[model] = iterating_models[model]
|
||||
return filtered_models
|
||||
|
||||
def get_filtered_model_names(self, **kwargs):
|
||||
filtered_models = self.get_filtered_models(**kwargs)
|
||||
return list(filtered_models.keys())
|
||||
|
||||
def get_dependency(self, dependency_name):
|
||||
return self.dependencies[dependency_name]
|
||||
|
||||
def get_model_files(self, model_name):
|
||||
return self.models[model_name]['config']['files']
|
||||
|
||||
def get_dependency_files(self, dependency_name):
|
||||
return self.dependencies[dependency_name]['config']['files']
|
||||
|
||||
def get_model_download(self, model_name):
|
||||
return self.models[model_name]['config']['download']
|
||||
|
||||
def get_dependency_download(self, dependency_name):
|
||||
return self.dependencies[dependency_name]['config']['download']
|
||||
|
||||
def get_available_models(self):
|
||||
return self.available_models
|
||||
|
||||
def get_available_dependencies(self):
|
||||
return self.available_dependencies
|
||||
|
||||
def get_loaded_models(self):
|
||||
return self.loaded_models
|
||||
|
||||
def get_loaded_models_names(self):
|
||||
return list(self.loaded_models.keys())
|
||||
|
||||
def get_loaded_model(self, model_name):
|
||||
return self.loaded_models[model_name]
|
||||
|
||||
def unload_model(self, model_name):
|
||||
if model_name in self.loaded_models:
|
||||
del self.loaded_models[model_name]
|
||||
return True
|
||||
return False
|
||||
|
||||
def unload_all_models(self):
|
||||
for model in self.loaded_models:
|
||||
del self.loaded_models[model]
|
||||
return True
|
||||
|
||||
def taint_model(self,model_name):
|
||||
'''Marks a model as not valid by remiving it from available_models'''
|
||||
if model_name in self.available_models:
|
||||
self.available_models.remove(model_name)
|
||||
self.tainted_models.append(model_name)
|
||||
|
||||
def taint_models(self, models):
|
||||
for model in models:
|
||||
self.taint_model(model)
|
||||
|
||||
def load_model_from_config(self, model_path='', config_path='', map_location="cpu"):
|
||||
config = OmegaConf.load(config_path)
|
||||
pl_sd = torch.load(model_path, map_location=map_location)
|
||||
if "global_step" in pl_sd:
|
||||
logger.info(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
model = model.eval()
|
||||
del pl_sd, sd, m, u
|
||||
return model
|
||||
|
||||
def load_ckpt(self, model_name='', precision='half', gpu_id=0):
|
||||
ckpt_path = self.get_model_files(model_name)[0]['path']
|
||||
config_path = self.get_model_files(model_name)[1]['path']
|
||||
model = self.load_model_from_config(model_path=ckpt_path, config_path=config_path)
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
model = (model if precision=='full' else model.half()).to(device)
|
||||
torch_gc()
|
||||
return {'model': model, 'device': device}
|
||||
|
||||
def load_realesrgan(self, model_name='', precision='half', gpu_id=0):
|
||||
|
||||
RealESRGAN_models = {
|
||||
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
|
||||
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
}
|
||||
|
||||
model_path = self.get_model_files(model_name)[0]['path']
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
model = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[models[model_name]['name']],
|
||||
pre_pad=0, half=True if precision == 'half' else False, device=device)
|
||||
return {'model': model, 'device': device}
|
||||
|
||||
def load_gfpgan(self, model_name='', gpu_id=0):
|
||||
|
||||
model_path = self.get_model_files(model_name)[0]['path']
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
model = GFPGANer(model_path=model_path, upscale=1, arch='clean',
|
||||
channel_multiplier=2, bg_upsampler=None, device=device)
|
||||
return {'model': model, 'device': device}
|
||||
|
||||
def load_blip(self, model_name='', precision='half', gpu_id=0, blip_image_eval_size=512, vit='base'):
|
||||
# vit = 'base' or 'large'
|
||||
model_path = self.get_model_files(model_name)[0]['path']
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
model = blip_decoder(pretrained=model_path,
|
||||
med_config="configs/blip/med_config.json",
|
||||
image_size=blip_image_eval_size, vit=vit)
|
||||
model = model.eval()
|
||||
model = (model if precision=='full' else model.half()).to(device)
|
||||
return {'model': model, 'device': device}
|
||||
|
||||
def load_open_clip(self, model_name='', precision='half', gpu_id=0):
|
||||
pretrained = self.get_model(model_name)['pretrained_name']
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
model, _, preprocesses = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, cache_dir='models/clip')
|
||||
model = model.eval()
|
||||
model = (model if precision=='full' else model.half()).to(device)
|
||||
return {'model': model, 'device': device, 'preprocesses': preprocesses}
|
||||
|
||||
def load_clip(self, model_name='', precision='half', gpu_id=0):
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
model, preprocesses = clip.load(model_name, device=device, download_root='models/clip')
|
||||
model = model.eval()
|
||||
model = (model if precision=='full' else model.half()).to(device)
|
||||
return {'model': model, 'device': device, 'preprocesses': preprocesses}
|
||||
|
||||
def load_model(self, model_name='', precision='half', gpu_id=0):
|
||||
if model_name not in self.available_models:
|
||||
return False
|
||||
if self.models[model_name]['type'] == 'ckpt':
|
||||
self.loaded_models[model_name] = self.load_ckpt(model_name, precision, gpu_id)
|
||||
return True
|
||||
elif self.models[model_name]['type'] == 'realesrgan':
|
||||
self.loaded_models[model_name] = self.load_realesrgan(model_name, precision, gpu_id)
|
||||
return True
|
||||
elif self.models[model_name]['type'] == 'gfpgan':
|
||||
self.loaded_models[model_name] = self.load_gfpgan(model_name, gpu_id)
|
||||
return True
|
||||
elif self.models[model_name]['type'] == 'blip':
|
||||
self.loaded_models[model_name] = self.load_blip(model_name, precision, gpu_id, 512, 'base')
|
||||
return True
|
||||
elif self.models[model_name]['type'] == 'open_clip':
|
||||
self.loaded_models[model_name] = self.load_open_clip(model_name, precision, gpu_id)
|
||||
return True
|
||||
elif self.models[model_name]['type'] == 'clip':
|
||||
self.loaded_models[model_name] = self.load_clip(model_name, precision, gpu_id)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def validate_model(self, model_name):
|
||||
files = self.get_model_files(model_name)
|
||||
all_ok = True
|
||||
for file_details in files:
|
||||
if not self.check_file_available(file_details['path']):
|
||||
return False
|
||||
if not self.validate_file(file_details):
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_file(self, file_details):
|
||||
if 'md5sum' in file_details:
|
||||
file_name = file_details['path']
|
||||
logger.debug(f"Getting md5sum of {file_name}")
|
||||
with open(file_name, 'rb') as file_to_check:
|
||||
file_hash = hashlib.md5()
|
||||
while chunk := file_to_check.read(8192):
|
||||
file_hash.update(chunk)
|
||||
if file_details['md5sum'] != file_hash.hexdigest():
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_file_available(self, file_path):
|
||||
return os.path.exists(file_path)
|
||||
|
||||
def check_available(self, files):
|
||||
available = True
|
||||
for file in files:
|
||||
if not self.check_file_available(file['path']):
|
||||
available = False
|
||||
return available
|
||||
|
||||
def download_file(self, url, file_path):
|
||||
# make directory
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
pbar_desc = file_path.split('/')[-1]
|
||||
r = requests.get(url, stream=True)
|
||||
with open(file_path, 'wb') as f:
|
||||
with tqdm(
|
||||
# all optional kwargs
|
||||
unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
|
||||
desc=pbar_desc, total=int(r.headers.get('content-length', 0))
|
||||
) as pbar:
|
||||
for chunk in r.iter_content(chunk_size=16*1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
pbar.update(len(chunk))
|
||||
|
||||
def download_model(self, model_name):
|
||||
if model_name in self.available_models:
|
||||
logger.info(f"{model_name} is already available.")
|
||||
return True
|
||||
download = self.get_model_download(model_name)
|
||||
files = self.get_model_files(model_name)
|
||||
for i in range(len(download)):
|
||||
file_path = f"{download[i]['file_path']}/{download[i]['file_name']}" if 'file_path' in download[i] else files[i]['path']
|
||||
|
||||
if 'file_url' in download[i]:
|
||||
download_url = download[i]['file_url']
|
||||
if 'hf_auth' in download[i]:
|
||||
username = self.hf_auth['username']
|
||||
password = self.hf_auth['password']
|
||||
download_url = download_url.format(username=username, password=password)
|
||||
if 'file_name' in download[i]:
|
||||
download_name = download[i]['file_name']
|
||||
if 'file_path' in download[i]:
|
||||
download_path = download[i]['file_path']
|
||||
|
||||
if 'manual' in download[i]:
|
||||
logger.warning(f"The model {model_name} requires manual download from {download_url}. Please place it in {download_path}/{download_name} then press ENTER to continue...")
|
||||
input('')
|
||||
continue
|
||||
# TODO: simplify
|
||||
if "file_content" in download[i]:
|
||||
file_content = download[i]['file_content']
|
||||
logger.info(f"writing {file_content} to {file_path}")
|
||||
# make directory download_path
|
||||
os.makedirs(download_path, exist_ok=True)
|
||||
# write file_content to download_path/download_name
|
||||
with open(os.path.join(download_path, download_name), 'w') as f:
|
||||
f.write(file_content)
|
||||
elif 'symlink' in download[i]:
|
||||
logger.info(f"symlink {file_path} to {download[i]['symlink']}")
|
||||
symlink = download[i]['symlink']
|
||||
# make directory symlink
|
||||
os.makedirs(download_path, exist_ok=True)
|
||||
# make symlink from download_path/download_name to symlink
|
||||
os.symlink(symlink, os.path.join(download_path, download_name))
|
||||
elif 'git' in download[i]:
|
||||
logger.info(f"git clone {download_url} to {file_path}")
|
||||
# make directory download_path
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
git.Git(file_path).clone(download_url)
|
||||
if 'post_process' in download[i]:
|
||||
for post_process in download[i]['post_process']:
|
||||
if 'delete' in post_process:
|
||||
# delete folder post_process['delete']
|
||||
logger.info(f"delete {post_process['delete']}")
|
||||
try:
|
||||
shutil.rmtree(post_process['delete'])
|
||||
except PermissionError as e:
|
||||
logger.error(f"[!] Something went wrong while deleting the `{post_process['delete']}`. Please delete it manually.")
|
||||
logger.error("PermissionError: ", e)
|
||||
else:
|
||||
if not self.check_file_available(file_path) or model_name in self.tainted_models:
|
||||
logger.debug(f'Downloading {download_url} to {file_path}')
|
||||
self.download_file(download_url, file_path)
|
||||
if not self.validate_model(model_name):
|
||||
return False
|
||||
if model_name in self.tainted_models:
|
||||
self.tainted_models.remove(model_name)
|
||||
self.init()
|
||||
return True
|
||||
|
||||
def download_dependency(self, dependency_name):
|
||||
if dependency_name in self.available_dependencies:
|
||||
logger.info(f"{dependency_name} is already installed.")
|
||||
return True
|
||||
download = self.get_dependency_download(dependency_name)
|
||||
files = self.get_dependency_files(dependency_name)
|
||||
for i in range(len(download)):
|
||||
if "git" in download[i]:
|
||||
logger.warning("git download not implemented yet")
|
||||
break
|
||||
|
||||
file_path = files[i]['path']
|
||||
if 'file_url' in download[i]:
|
||||
download_url = download[i]['file_url']
|
||||
if 'file_name' in download[i]:
|
||||
download_name = download[i]['file_name']
|
||||
if 'file_path' in download[i]:
|
||||
download_path = download[i]['file_path']
|
||||
logger.debug(download_name)
|
||||
if "unzip" in download[i]:
|
||||
zip_path = f'temp/{download_name}.zip'
|
||||
# os dirname zip_path
|
||||
# mkdir temp
|
||||
os.makedirs("temp", exist_ok=True)
|
||||
|
||||
self.download_file(download_url, zip_path)
|
||||
logger.info(f"unzip {zip_path}")
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall('temp/')
|
||||
# move temp/sd-concepts-library-main/sd-concepts-library to download_path
|
||||
logger.info(f"move temp/{download_name}-main/{download_name} to {download_path}")
|
||||
shutil.move(f"temp/{download_name}-main/{download_name}", download_path)
|
||||
logger.info(f"delete {zip_path}")
|
||||
os.remove(zip_path)
|
||||
logger.info(f"delete temp/{download_name}-main/")
|
||||
shutil.rmtree(f"temp/{download_name}-main")
|
||||
else:
|
||||
if not self.check_file_available(file_path):
|
||||
logger.init(f'{file_path}', status="Downloading")
|
||||
self.download_file(download_url, file_path)
|
||||
self.init()
|
||||
return True
|
||||
|
||||
def download_all_models(self):
|
||||
for model in self.get_filtered_model_names(download_all = True):
|
||||
if not self.check_model_available(model):
|
||||
logger.init(f"{model}", status="Downloading")
|
||||
self.download_model(model)
|
||||
else:
|
||||
logger.info(f"{model} is already downloaded.")
|
||||
return True
|
||||
|
||||
def download_all_dependencies(self):
|
||||
for dependency in self.dependencies:
|
||||
if not self.check_dependency_available(dependency):
|
||||
logger.init(f"{dependency}",status="Downloading")
|
||||
self.download_dependency(dependency)
|
||||
else:
|
||||
logger.info(f"{dependency} is already installed.")
|
||||
return True
|
||||
|
||||
def download_all(self):
|
||||
self.download_all_dependencies()
|
||||
self.download_all_models()
|
||||
return True
|
||||
|
||||
def check_all_available(self):
|
||||
for model in self.models:
|
||||
if not self.check_available(self.get_model_files(model)):
|
||||
return False
|
||||
for dependency in self.dependencies:
|
||||
if not self.check_available(self.get_dependency_files(dependency)):
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_model_available(self, model_name):
|
||||
if model_name not in self.models:
|
||||
return False
|
||||
return self.check_available(self.get_model_files(model_name))
|
||||
|
||||
def check_dependency_available(self, dependency_name):
|
||||
if dependency_name not in self.dependencies:
|
||||
return False
|
||||
return self.check_available(self.get_dependency_files(dependency_name))
|
||||
|
||||
def check_all_available(self):
|
||||
for model in self.models:
|
||||
if not self.check_model_available(model):
|
||||
return False
|
||||
for dependency in self.dependencies:
|
||||
if not self.check_dependency_available(dependency):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
0
scripts/nataili/postprocess/__init__.py
Normal file
0
scripts/nataili/postprocess/__init__.py
Normal file
48
scripts/nataili/postprocess/upscaler.py
Normal file
48
scripts/nataili/postprocess/upscaler.py
Normal file
@ -0,0 +1,48 @@
|
||||
# Class realesrgan
|
||||
# Inputs:
|
||||
# - model
|
||||
# - device
|
||||
# - output_dir
|
||||
# - output_ext
|
||||
# outupts:
|
||||
# - output_images
|
||||
import PIL
|
||||
from torchvision import transforms
|
||||
import numpy as np
|
||||
import os
|
||||
import cv2
|
||||
|
||||
from nataili.util.save_sample import save_sample
|
||||
|
||||
class realesrgan:
|
||||
def __init__(self, model, device, output_dir, output_ext='jpg'):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.output_dir = output_dir
|
||||
self.output_ext = output_ext
|
||||
self.output_images = []
|
||||
|
||||
def generate(self, input_image):
|
||||
# load image
|
||||
img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
|
||||
if len(img.shape) == 3 and img.shape[2] == 4:
|
||||
img_mode = 'RGBA'
|
||||
else:
|
||||
img_mode = None
|
||||
# upscale
|
||||
output, _ = self.model.enhance(img)
|
||||
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||
self.output_ext = 'png'
|
||||
|
||||
esrgan_sample = output[:,:,::-1]
|
||||
esrgan_image = PIL.Image.fromarray(esrgan_sample)
|
||||
# append model name to output image name
|
||||
filename = os.path.basename(input_image)
|
||||
filename = os.path.splitext(filename)[0]
|
||||
filename = f'{filename}_esrgan'
|
||||
filename_with_ext = f'{filename}.{self.output_ext}'
|
||||
output_image = os.path.join(self.output_dir, filename_with_ext)
|
||||
save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
|
||||
self.output_images.append(output_image)
|
||||
return
|
||||
|
0
scripts/nataili/upscalers/__init__.py
Normal file
0
scripts/nataili/upscalers/__init__.py
Normal file
48
scripts/nataili/upscalers/realesrgan.py
Normal file
48
scripts/nataili/upscalers/realesrgan.py
Normal file
@ -0,0 +1,48 @@
|
||||
# Class realesrgan
|
||||
# Inputs:
|
||||
# - model
|
||||
# - device
|
||||
# - output_dir
|
||||
# - output_ext
|
||||
# outupts:
|
||||
# - output_images
|
||||
import PIL
|
||||
from torchvision import transforms
|
||||
import numpy as np
|
||||
import os
|
||||
import cv2
|
||||
|
||||
from nataili.util.save_sample import save_sample
|
||||
|
||||
class realesrgan:
|
||||
def __init__(self, model, device, output_dir, output_ext='jpg'):
|
||||
self.model = model
|
||||
self.device = device
|
||||
self.output_dir = output_dir
|
||||
self.output_ext = output_ext
|
||||
self.output_images = []
|
||||
|
||||
def generate(self, input_image):
|
||||
# load image
|
||||
img = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
|
||||
if len(img.shape) == 3 and img.shape[2] == 4:
|
||||
img_mode = 'RGBA'
|
||||
else:
|
||||
img_mode = None
|
||||
# upscale
|
||||
output, _ = self.model.enhance(img)
|
||||
if img_mode == 'RGBA': # RGBA images should be saved in png format
|
||||
self.output_ext = 'png'
|
||||
|
||||
esrgan_sample = output[:,:,::-1]
|
||||
esrgan_image = PIL.Image.fromarray(esrgan_sample)
|
||||
# append model name to output image name
|
||||
filename = os.path.basename(input_image)
|
||||
filename = os.path.splitext(filename)[0]
|
||||
filename = f'{filename}_esrgan'
|
||||
filename_with_ext = f'{filename}.{self.output_ext}'
|
||||
output_image = os.path.join(self.output_dir, filename_with_ext)
|
||||
save_sample(esrgan_image, filename, self.output_dir, self.output_ext)
|
||||
self.output_images.append(output_image)
|
||||
return
|
||||
|
1
scripts/nataili/util/__init__.py
Normal file
1
scripts/nataili/util/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from nataili.util.logger import logger,set_logger_verbosity, quiesce_logger, test_logger
|
16
scripts/nataili/util/cache.py
Normal file
16
scripts/nataili/util/cache.py
Normal file
@ -0,0 +1,16 @@
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import threading
|
||||
import pynvml
|
||||
import time
|
||||
|
||||
with torch.no_grad():
|
||||
def torch_gc():
|
||||
for _ in range(2):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
18
scripts/nataili/util/check_prompt_length.py
Normal file
18
scripts/nataili/util/check_prompt_length.py
Normal file
@ -0,0 +1,18 @@
|
||||
def check_prompt_length(model, prompt, comments):
|
||||
"""this function tests if prompt is too long, and if so, adds a message to comments"""
|
||||
|
||||
tokenizer = model.cond_stage_model.tokenizer
|
||||
max_length = model.cond_stage_model.max_length
|
||||
|
||||
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
|
||||
return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
|
||||
ovf = info['overflowing_tokens'][0]
|
||||
overflowing_count = ovf.shape[0]
|
||||
if overflowing_count == 0:
|
||||
return
|
||||
|
||||
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
del tokenizer
|
22
scripts/nataili/util/get_next_sequence_number.py
Normal file
22
scripts/nataili/util/get_next_sequence_number.py
Normal file
@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
|
||||
def get_next_sequence_number(path, prefix=''):
|
||||
"""
|
||||
Determines and returns the next sequence number to use when saving an
|
||||
image in the specified directory.
|
||||
|
||||
If a prefix is given, only consider files whose names start with that
|
||||
prefix, and strip the prefix from filenames before extracting their
|
||||
sequence number.
|
||||
|
||||
The sequence starts at 0.
|
||||
"""
|
||||
result = -1
|
||||
for p in Path(path).iterdir():
|
||||
if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
|
||||
tmp = p.name[len(prefix):]
|
||||
try:
|
||||
result = max(int(tmp.split('-')[0]), result)
|
||||
except ValueError:
|
||||
pass
|
||||
return result + 1
|
21
scripts/nataili/util/image_grid.py
Normal file
21
scripts/nataili/util/image_grid.py
Normal file
@ -0,0 +1,21 @@
|
||||
import math
|
||||
|
||||
import PIL
|
||||
|
||||
|
||||
def image_grid(imgs, n_rows=None):
|
||||
if n_rows is not None:
|
||||
rows = n_rows
|
||||
else:
|
||||
rows = math.sqrt(len(imgs))
|
||||
rows = round(rows)
|
||||
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = PIL.Image.new('RGB', size=(cols * w, rows * h), color='black')
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
|
||||
return grid
|
40
scripts/nataili/util/load_learned_embed_in_clip.py
Normal file
40
scripts/nataili/util/load_learned_embed_in_clip.py
Normal file
@ -0,0 +1,40 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
|
||||
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
|
||||
# separate token and the embeds
|
||||
if learned_embeds_path.endswith('.pt'):
|
||||
# old format
|
||||
# token = * so replace with file directory name when converting
|
||||
trained_token = os.path.basename(learned_embeds_path)
|
||||
params_dict = {
|
||||
trained_token: torch.tensor(list(loaded_learned_embeds['string_to_param'].items())[0][1])
|
||||
}
|
||||
learned_embeds_path = os.path.splitext(learned_embeds_path)[0] + '.bin'
|
||||
torch.save(params_dict, learned_embeds_path)
|
||||
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
|
||||
trained_token = list(loaded_learned_embeds.keys())[0]
|
||||
embeds = loaded_learned_embeds[trained_token]
|
||||
elif learned_embeds_path.endswith('.bin'):
|
||||
trained_token = list(loaded_learned_embeds.keys())[0]
|
||||
embeds = loaded_learned_embeds[trained_token]
|
||||
|
||||
embeds = loaded_learned_embeds[trained_token]
|
||||
# cast to dtype of text_encoder
|
||||
dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||
embeds.to(dtype)
|
||||
|
||||
# add the token in tokenizer
|
||||
token = token if token is not None else trained_token
|
||||
num_added_tokens = tokenizer.add_tokens(token)
|
||||
|
||||
# resize the token embeddings
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# get the id for the token and assign the embeds
|
||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||
return token
|
102
scripts/nataili/util/logger.py
Normal file
102
scripts/nataili/util/logger.py
Normal file
@ -0,0 +1,102 @@
|
||||
import sys
|
||||
from functools import partialmethod
|
||||
from loguru import logger
|
||||
|
||||
STDOUT_LEVELS = ["GENERATION", "PROMPT"]
|
||||
INIT_LEVELS = ["INIT", "INIT_OK", "INIT_WARN", "INIT_ERR"]
|
||||
MESSAGE_LEVELS = ["MESSAGE"]
|
||||
# By default we're at error level or higher
|
||||
verbosity = 20
|
||||
quiet = 0
|
||||
|
||||
def set_logger_verbosity(count):
|
||||
global verbosity
|
||||
# The count comes reversed. So count = 0 means minimum verbosity
|
||||
# While count 5 means maximum verbosity
|
||||
# So the more count we have, the lowe we drop the versbosity maximum
|
||||
verbosity = 20 - (count * 10)
|
||||
|
||||
def quiesce_logger(count):
|
||||
global quiet
|
||||
# The bigger the count, the more silent we want our logger
|
||||
quiet = count * 10
|
||||
|
||||
def is_stdout_log(record):
|
||||
if record["level"].name not in STDOUT_LEVELS:
|
||||
return(False)
|
||||
if record["level"].no < verbosity + quiet:
|
||||
return(False)
|
||||
return(True)
|
||||
|
||||
def is_init_log(record):
|
||||
if record["level"].name not in INIT_LEVELS:
|
||||
return(False)
|
||||
if record["level"].no < verbosity + quiet:
|
||||
return(False)
|
||||
return(True)
|
||||
|
||||
def is_msg_log(record):
|
||||
if record["level"].name not in MESSAGE_LEVELS:
|
||||
return(False)
|
||||
if record["level"].no < verbosity + quiet:
|
||||
return(False)
|
||||
return(True)
|
||||
|
||||
def is_stderr_log(record):
|
||||
if record["level"].name in STDOUT_LEVELS + INIT_LEVELS + MESSAGE_LEVELS:
|
||||
return(False)
|
||||
if record["level"].no < verbosity + quiet:
|
||||
return(False)
|
||||
return(True)
|
||||
|
||||
def test_logger():
|
||||
logger.generation("This is a generation message\nIt is typically multiline\nThee Lines".encode("unicode_escape").decode("utf-8"))
|
||||
logger.prompt("This is a prompt message")
|
||||
logger.debug("Debug Message")
|
||||
logger.info("Info Message")
|
||||
logger.warning("Info Warning")
|
||||
logger.error("Error Message")
|
||||
logger.critical("Critical Message")
|
||||
logger.init("This is an init message", status="Starting")
|
||||
logger.init_ok("This is an init message", status="OK")
|
||||
logger.init_warn("This is an init message", status="Warning")
|
||||
logger.init_err("This is an init message", status="Error")
|
||||
logger.message("This is user message")
|
||||
sys.exit()
|
||||
|
||||
|
||||
logfmt = "<level>{level: <10}</level> | <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <green>{name}</green>:<green>{function}</green>:<green>{line}</green> - <level>{message}</level>"
|
||||
genfmt = "<level>{level: <10}</level> @ <green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{message}</level>"
|
||||
initfmt = "<magenta>INIT </magenta> | <level>{extra[status]: <11}</level> | <magenta>{message}</magenta>"
|
||||
msgfmt = "<level>{level: <10}</level> | <level>{message}</level>"
|
||||
|
||||
try:
|
||||
logger.level("GENERATION", no=24, color="<cyan>")
|
||||
logger.level("PROMPT", no=23, color="<yellow>")
|
||||
logger.level("INIT", no=31, color="<white>")
|
||||
logger.level("INIT_OK", no=31, color="<green>")
|
||||
logger.level("INIT_WARN", no=31, color="<yellow>")
|
||||
logger.level("INIT_ERR", no=31, color="<red>")
|
||||
# Messages contain important information without which this application might not be able to be used
|
||||
# As such, they have the highest priority
|
||||
logger.level("MESSAGE", no=61, color="<green>")
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
logger.__class__.generation = partialmethod(logger.__class__.log, "GENERATION")
|
||||
logger.__class__.prompt = partialmethod(logger.__class__.log, "PROMPT")
|
||||
logger.__class__.init = partialmethod(logger.__class__.log, "INIT")
|
||||
logger.__class__.init_ok = partialmethod(logger.__class__.log, "INIT_OK")
|
||||
logger.__class__.init_warn = partialmethod(logger.__class__.log, "INIT_WARN")
|
||||
logger.__class__.init_err = partialmethod(logger.__class__.log, "INIT_ERR")
|
||||
logger.__class__.message = partialmethod(logger.__class__.log, "MESSAGE")
|
||||
|
||||
config = {
|
||||
"handlers": [
|
||||
{"sink": sys.stderr, "format": logfmt, "colorize":True, "filter": is_stderr_log},
|
||||
{"sink": sys.stdout, "format": genfmt, "level": "PROMPT", "colorize":True, "filter": is_stdout_log},
|
||||
{"sink": sys.stdout, "format": initfmt, "level": "INIT", "colorize":True, "filter": is_init_log},
|
||||
{"sink": sys.stdout, "format": msgfmt, "level": "MESSAGE", "colorize":True, "filter": is_msg_log}
|
||||
],
|
||||
}
|
||||
logger.configure(**config)
|
20
scripts/nataili/util/save_sample.py
Normal file
20
scripts/nataili/util/save_sample.py
Normal file
@ -0,0 +1,20 @@
|
||||
import os
|
||||
|
||||
def save_sample(image, filename, sample_path, extension='png', jpg_quality=95, webp_quality=95, webp_lossless=True, png_compression=9):
|
||||
path = os.path.join(sample_path, filename + '.' + extension)
|
||||
if os.path.exists(path):
|
||||
return False
|
||||
if not os.path.exists(sample_path):
|
||||
os.makedirs(sample_path)
|
||||
if extension == 'png':
|
||||
image.save(path, format='PNG', compress_level=png_compression)
|
||||
elif extension == 'jpg':
|
||||
image.save(path, quality=jpg_quality, optimize=True)
|
||||
elif extension == 'webp':
|
||||
image.save(path, quality=webp_quality, lossless=webp_lossless)
|
||||
else:
|
||||
return False
|
||||
if os.path.exists(path):
|
||||
return True
|
||||
else:
|
||||
return False
|
22
scripts/nataili/util/seed_to_int.py
Normal file
22
scripts/nataili/util/seed_to_int.py
Normal file
@ -0,0 +1,22 @@
|
||||
import random
|
||||
|
||||
def seed_to_int(s):
|
||||
if type(s) is int:
|
||||
return s
|
||||
if s is None or s == '':
|
||||
return random.randint(0, 2**32 - 1)
|
||||
|
||||
if type(s) is list:
|
||||
seed_list = []
|
||||
for seed in s:
|
||||
if seed is None or seed == '':
|
||||
seed_list.append(random.randint(0, 2**32 - 1))
|
||||
else:
|
||||
seed_list = s
|
||||
|
||||
return seed_list
|
||||
|
||||
n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
|
||||
while n >= 2**32:
|
||||
n = n >> 32
|
||||
return n
|
Loading…
Reference in New Issue
Block a user