ruff manual fixes

This commit is contained in:
AUTOMATIC 2023-05-10 11:19:16 +03:00
parent 028d3f6425
commit 550256db1c
15 changed files with 69 additions and 41 deletions

View File

@ -24,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
@ -62,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@ -81,11 +81,11 @@ class VQModel(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
@ -270,7 +270,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):

View File

@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
@ -444,7 +444,7 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
@ -1418,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]

View File

@ -644,13 +644,17 @@ class SwinIR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
embed_dim=96, depths=None, num_heads=None,
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
depths = depths or [6, 6, 6, 6]
num_heads = num_heads or [6, 6, 6, 6]
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64

View File

@ -74,9 +74,12 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]):
pretrained_window_size=None):
super().__init__()
pretrained_window_size = pretrained_window_size or [0, 0]
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
@ -698,13 +701,17 @@ class Swin2SR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
embed_dim=96, depths=None, num_heads=None,
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(Swin2SR, self).__init__()
depths = depths or [6, 6, 6, 6]
num_heads = num_heads or [6, 6, 6, 6]
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64

View File

@ -34,14 +34,16 @@ import piexif.helper
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except Exception:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}")
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
def script_name_to_index(name, scripts):
try:
return [script.title().lower() for script in scripts].index(name.lower())
except Exception:
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
except Exception as e:
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
@ -50,20 +52,23 @@ def validate_sampler_name(name):
return name
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
except Exception:
raise HTTPException(status_code=500, detail="Invalid encoded image")
except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
@ -94,6 +99,7 @@ def encode_pil_to_base64(image):
return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI):
rich_available = True
try:

View File

@ -161,10 +161,13 @@ class Fuse_sft_block(nn.Module):
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
connect_list=['32', '64', '128', '256'],
fix_modules=['quantize','generator']):
connect_list=None,
fix_modules=None):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
connect_list = connect_list or ['32', '64', '128', '256']
fix_modules = fix_modules or ['quantize', 'generator']
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():

View File

@ -326,7 +326,7 @@ class Generator(nn.Module):
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
@ -337,7 +337,7 @@ class VQAutoEncoder(nn.Module):
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.attn_resolutions = attn_resolutions or [16]
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,

View File

@ -19,14 +19,14 @@ registered_param_bindings = []
class ParamBinding:
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
self.paste_field_names = paste_field_names
self.paste_field_names = paste_field_names or []
def reset():

View File

@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
# If initialing from EMA-only checkpoint, create EMA model after loading.
if self.use_ema and not load_ema:
@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
ignore_keys = ignore_keys or []
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
@ -473,7 +475,7 @@ class LatentDiffusion(DDPM):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
@ -1433,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]

View File

@ -178,13 +178,13 @@ def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
model_kwargs=None,
guidance_type="uncond",
#condition=None,
#unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
classifier_kwargs={},
classifier_kwargs=None,
):
"""Create a wrapper function for the noise prediction model.
@ -275,6 +275,9 @@ def model_wrapper(
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
model_kwargs = model_kwargs or []
classifier_kwargs = classifier_kwargs or []
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.

View File

@ -104,7 +104,7 @@ def check_pt(filename, extra_handler):
def load(filename, *args, **kwargs):
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):

View File

@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)

View File

@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder):
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
def object_hook(self, d):
if 'TORCHTENSOR' in d:

View File

@ -32,8 +32,8 @@ class LearnScheduleIterator:
self.maxit += 1
return
assert self.rates
except (ValueError, AssertionError):
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
except (ValueError, AssertionError) as e:
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
def __iter__(self):

View File

@ -24,6 +24,9 @@ ignore = [
]
[tool.ruff.per-file-ignores]
"webui.py" = ["E402"] # Module level import not at top of file
[tool.ruff.flake8-bugbear]
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]