Mirror of https://github.com/openvinotoolkit/stable-diffusion-webui.git (synced 2024-12-15 15:13:45 +03:00)
Commit 0d07cbfa15 (parent 0abb39f461)
Commit message: I blame code autocomplete

The diff below, against modules/hypernetworks/hypernetwork.py, reverts editor-driven reformatting: long call and signature lines that had been wrapped across continuation lines are joined back into single lines, and the file's original spacing quirks ("i+1", "loss_info:dict", "lambda : deque(maxlen = 1024)") are restored. Several hunks appear to differ only in whitespace.
@@ -33,12 +33,9 @@ class HypernetworkModule(torch.nn.Module):
         "tanh": torch.nn.Tanh,
         "sigmoid": torch.nn.Sigmoid,
     }
-    activation_dict.update(
-        {cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
-         inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
+    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
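
For context on the comprehension this hunk joins back onto one line: it extends the hand-written activation table with every activation class that torch.nn.modules.activation defines, keyed by lowercased class name. A standalone sketch, seeded with a few of the hand-written entries from the class body above:

    import inspect

    import torch

    activation_dict = {
        "linear": torch.nn.Identity,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }

    # Harvest every class defined in torch.nn.modules.activation, keyed by
    # lowercased name, e.g. "relu" -> torch.nn.ReLU. The __module__ check
    # skips names that are merely imported into that namespace.
    activation_dict.update({cls_name.lower(): cls_obj
                            for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation)
                            if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    print(sorted(activation_dict))  # 'elu', 'gelu', 'leakyrelu', 'relu', ...

The dynamic update is why the UI can offer activations the hand-written table never listed; joining it onto one line changes nothing semantically.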
@@ -49,7 +46,7 @@ class HypernetworkModule(torch.nn.Module):
         for i in range(len(layer_structure) - 1):
 
             # Add a fully-connected layer
-            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
+            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
 
             # Add an activation func
             if activation_func == "linear" or activation_func is None:
@@ -61,7 +58,7 @@ class HypernetworkModule(torch.nn.Module):
 
             # Add layer normalization
             if add_layer_norm:
-                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
             # Add dropout expect last layer
             if use_dropout and i < len(layer_structure) - 3:
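
Both `i + 1` to `i+1` tweaks sit inside HypernetworkModule's layer-building loop, where layer_structure is a list of width multipliers applied to dim. A simplified, self-contained sketch of that loop; weight initialization and state-dict loading from the real module are omitted, and the dropout probability is an assumption for this sketch:

    import torch

    def build_layers(dim, layer_structure, activation_func=None,
                     add_layer_norm=False, use_dropout=False):
        activation_dict = {"relu": torch.nn.ReLU, "tanh": torch.nn.Tanh}  # trimmed lookup
        linears = []
        for i in range(len(layer_structure) - 1):
            # Add a fully-connected layer between consecutive multiplier widths
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
            # Add an activation func ("linear"/None means add nothing)
            if activation_func is not None and activation_func != "linear":
                linears.append(activation_dict[activation_func]())
            # Add layer normalization over the layer's output width
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
            # Add dropout except around the last layers (the i < len(...) - 3 guard)
            if use_dropout and i < len(layer_structure) - 3:
                linears.append(torch.nn.Dropout(p=0.3))  # p is assumed here
        return torch.nn.Sequential(*linears)

    print(build_layers(320, [1, 2, 1], activation_func="relu", add_layer_norm=True))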
@@ -130,8 +127,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
-                 add_layer_norm=False, use_dropout=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -146,10 +142,8 @@ class Hypernetwork:
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout),
-                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
             )
 
     def weights(self):
@@ -196,15 +190,13 @@ class Hypernetwork:
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
         print(f"Layer norm is set to {self.add_layer_norm}")
         self.use_dropout = state_dict.get('use_dropout', False)
-        print(f"Dropout usage is set to {self.use_dropout}")
+        print(f"Dropout usage is set to {self.use_dropout}" )
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                 )
 
         self.name = state_dict.get('name', self.name)
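
Both rewrapped constructor calls build the same structure: for every enabled size the hypernetwork keeps a pair of HypernetworkModules, and on load the per-size entry sd holds two state dicts, sd[0] and sd[1], one per module. In the webui the pair transforms the attention key and value contexts respectively. A rough sketch of the container shape, with TinyModule as a hypothetical stand-in:

    import torch

    class TinyModule(torch.nn.Module):
        # Hypothetical stand-in for HypernetworkModule.
        def __init__(self, dim, state_dict=None):
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)
            if state_dict is not None:
                self.load_state_dict(state_dict)

    layers = {}
    for size in [320, 640, 768, 1280]:  # context widths enabled for training
        layers[size] = (
            TinyModule(size),  # applied to attention keys
            TinyModule(size),  # applied to attention values
        )

    # Loading mirrors the same shape: state[size] == (sd_for_k, sd_for_v)
    state = {size: (k.state_dict(), v.state_dict()) for size, (k, v) in layers.items()}
    restored = {size: (TinyModule(size, sd[0]), TinyModule(size, sd[1]))
                for size, sd in state.items()}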
@@ -316,7 +308,7 @@ def statistics(data):
         std = 0
     else:
         std = stdev(data)
-    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
+    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
     recent_data = data[-32:]
     if len(recent_data) < 2:
         std = 0
@@ -326,7 +318,7 @@ def statistics(data):
     return total_information, recent_information
 
 
-def report_statistics(loss_info: dict):
+def report_statistics(loss_info:dict):
     keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
     for key in keys:
         try:
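
The string these two hunks touch is a mean plus-or-minus standard-error readout: stdev(data) / len(data) ** 0.5 is the standard error of the mean, and u"\u00B1" is the ± sign. A runnable reconstruction of statistics() pieced together from the context lines; the recent-window branch and its label are inferred, not shown in the diff:

    from statistics import mean, stdev

    def statistics(data):
        if len(data) < 2:
            std = 0  # stdev() raises on fewer than two points
        else:
            std = stdev(data)
        # mean loss +/- standard error over the whole history
        total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
        recent_data = data[-32:]  # same readout over the last 32 points
        if len(recent_data) < 2:
            std = 0
        else:
            std = stdev(recent_data)
        recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
        return total_information, recent_information

    print(statistics([0.2, 0.3, 0.2, 0.3]))
    # ('loss:0.250±(0.029)', 'recent 32 loss:0.250±(0.029)')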
@@ -338,18 +330,14 @@ def report_statistics(loss_info: dict):
             print(e)
 
 
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
-                       training_height, steps, create_image_every, save_hypernetwork_every, template_file,
-                       preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
-                       preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
-    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
-                                            save_hypernetwork_every, create_image_every, log_directory,
-                                            name="hypernetwork")
+    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()
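
Two small idioms frame the rewrapped validation call: `x = x or 0` collapses a None interval from the UI into a concrete "disabled" value, and all cheap input checks run before the slow dataset load further down. A hypothetical miniature of that fail-fast ordering, with validate_inputs standing in for textual_inversion.validate_train_inputs:

    def validate_inputs(name, steps, save_every, image_every):
        # Stand-in: fail fast on bad settings before any heavy work starts.
        assert name, "hypernetwork name is empty"
        assert steps > 0, "max steps must be positive"
        assert save_every >= 0 and image_every >= 0, "intervals must be >= 0"

    save_hypernetwork_every = None  # the UI may hand over None instead of 0
    create_image_every = 500

    save_hypernetwork_every = save_hypernetwork_every or 0  # None -> 0 (disabled)
    create_image_every = create_image_every or 0

    validate_inputs("my-hypernet", 2000, save_hypernetwork_every, create_image_every)
    # ...only after this point would the expensive dataset preparation start.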
@@ -384,29 +372,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename
 
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
 
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
-                                                                height=training_height,
-                                                                repeats=shared.opts.training_image_repeats_per_epoch,
-                                                                placeholder_token=hypernetwork_name,
-                                                                model=shared.sd_model, device=devices.device,
-                                                                template_file=template_file, include_cond=True,
-                                                                batch_size=batch_size)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
 
     size = len(ds.indexes)
-    loss_dict = defaultdict(lambda: deque(maxlen=1024))
+    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
     losses = torch.zeros((size,))
     previous_mean_losses = [0]
     previous_mean_loss = 0
     print("Mean loss of {} elements".format(size))
 
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
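
The loss bookkeeping whose spacing this hunk reverts deserves a note: defaultdict plus a bounded deque keeps at most the 1024 most recent loss values per training image, so per-image statistics stay cheap and memory stays flat over long runs. A standalone sketch:

    from collections import defaultdict, deque

    # Any unseen filename key silently gets its own bounded history.
    loss_dict = defaultdict(lambda: deque(maxlen=1024))

    for step in range(3000):
        loss_dict["img_0001.png"].append(0.2)  # stand-in for loss.item()

    print(len(loss_dict["img_0001.png"]))  # 1024 -- the oldest values were dropped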
@@ -425,7 +407,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         if len(loss_dict) > 0:
             previous_mean_losses = [i[-1] for i in loss_dict.values()]
             previous_mean_loss = mean(previous_mean_losses)
 
         scheduler.apply(optimizer, hypernetwork.step)
         if scheduler.finished:
             break
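
LearnRateScheduler is internal to the webui, and the loop relies on just the two touchpoints visible here: apply(optimizer, step) pushes the rate for the current step into the optimizer, and finished flips once the schedule is exhausted. A hypothetical minimal stand-in with the same surface:

    class StepwiseLR:
        # Hypothetical minimal scheduler over (rate, last_step) pairs.
        def __init__(self, schedule, initial_step=0):
            self.schedule = sorted(schedule, key=lambda pair: pair[1])
            self.max_step = self.schedule[-1][1]
            self.finished = initial_step >= self.max_step

        def apply(self, optimizer, step):
            if step >= self.max_step:
                self.finished = True
                return
            # The first entry whose end step has not passed sets the rate.
            rate = next(lr for lr, until in self.schedule if step < until)
            for group in optimizer.param_groups:
                group['lr'] = rate

    # Usage sketch: scheduler = StepwiseLR([(5e-5, 100), (5e-6, 10000)])
    #               scheduler.apply(optimizer, step); if scheduler.finished: break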
@@ -444,7 +426,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             losses[hypernetwork.step % losses.shape[0]] = loss.item()
             for entry in entries:
                 loss_dict[entry.filename].append(loss.item())
 
             optimizer.zero_grad()
             weights[0].grad = None
             loss.backward()
@@ -459,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
         steps_done = hypernetwork.step + 1
 
         if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
             raise RuntimeError("Loss diverged.")
 
         if len(previous_mean_losses) > 1:
             std = stdev(previous_mean_losses)
         else:
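
The isnan check in this hunk is the loop's divergence guard: a NaN loss means the gradients are already unusable, so the run aborts instead of training on garbage. Reproduced in miniature:

    import torch

    losses = torch.zeros((4,))  # ring buffer of recent losses
    step = 2
    losses[step % losses.shape[0]] = float("nan")  # stand-in for a diverged loss.item()

    try:
        if torch.isnan(losses[step % losses.shape[0]]):
            raise RuntimeError("Loss diverged.")
    except RuntimeError as err:
        print(err)  # Loss diverged.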
@@ -510,7 +492,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                 preview_text = p.prompt
 
                 processed = processing.process_images(p)
-                image = processed.images[0] if len(processed.images) > 0 else None
+                image = processed.images[0] if len(processed.images)>0 else None
 
                 if unload:
                     shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -518,10 +500,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
 
             if image is not None:
                 shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
-                                                                     shared.opts.samples_format, processed.infotexts[0],
-                                                                     p=p, forced_filename=forced_filename,
-                                                                     save_to_dirs=False)
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"
 
         shared.state.job_no = hypernetwork.step
@@ -535,7 +514,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
 
     report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
@@ -543,7 +522,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 
     return hypernetwork, filename
 
-
 def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
     old_hypernetwork_name = hypernetwork.name
     old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
@@ -557,4 +535,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
         hypernetwork.sd_checkpoint = old_sd_checkpoint
         hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
         hypernetwork.name = old_hypernetwork_name
         raise