Variable dropout rate

Implements variable dropout rate from #4549

Fixes the hypernetwork multiplier being modifiable during training; this also prevents user error where the multiplier was left at a low value while training.

Changes function names to match the torch.nn.Module convention (train_mode/eval_mode become train/eval)

Fixes the RNG reset issue when generating previews, by saving and restoring the RNG state around preview generation
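Note: in practice this replaces the single hard-coded Dropout(p=0.3) per eligible layer with a per-layer probability list, so a layer_structure of [1, 2, 2, 1] (illustrative values) can now train with a dropout_structure like [0, 0.05, 0.15, 0] instead of the fixed [0, 0.3, 0.3, 0].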
Author: aria1th
Date: 2023-01-10 14:56:57 +09:00
Commit: a4a5475cfa (parent: bd4587d2f5)
3 changed files with 81 additions and 28 deletions

File: modules/hypernetworks/hypernetwork.py

@@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module):
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

     def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
+                 add_layer_norm=False, activate_output=False, dropout_structure=None):
         super().__init__()

         assert layer_structure is not None, "layer_structure must not be None"
@@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module):
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

-            # Add dropout except last layer
-            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
-                linears.append(torch.nn.Dropout(p=0.3))
+            # Everything should be now parsed into dropout structure, and applied here.
+            # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
+            if dropout_structure is not None and dropout_structure[i+1] > 0:
+                assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
+                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
+            # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].

         self.linear = torch.nn.Sequential(*linears)
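Note: to make the construction concrete, here is a minimal standalone sketch of what the loop now builds, stripped of the activation, weight-init, and layer-norm handling; dim is a made-up width:

    import torch

    dim = 768  # hypothetical width; the real module receives the attention dimension
    layer_structure = [1, 2, 1]
    dropout_structure = [0, 0.3, 0]

    linears = []
    for i in range(len(layer_structure) - 1):
        # Linear always; Dropout only where the structure carries a nonzero probability
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
        if dropout_structure[i+1] > 0:
            linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))

    linear = torch.nn.Sequential(*linears)
    # -> Sequential(Linear(768, 1536), Dropout(p=0.3), Linear(1536, 768))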
@@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
         state_dict[to] = x

     def forward(self, x):
-        return x + self.linear(x) * self.multiplier
+        return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)

     def trainables(self):
         layer_structure = []
@@ -126,6 +129,21 @@ def apply_strength(value=None):
 def apply_strength(value=None):
     HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength

+#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
+def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
+    if layer_structure is None:
+        layer_structure = [1, 2, 1]
+    if not use_dropout:
+        return [0] * len(layer_structure)
+    dropout_values = [0]
+    dropout_values.extend([0.3] * (len(layer_structure) - 3))
+    if last_layer_dropout:
+        dropout_values.append(0.3)
+    else:
+        dropout_values.append(0)
+    dropout_values.append(0)
+    return dropout_values
+

 class Hypernetwork:
     filename = None
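Note: as a quick check of the helper above, these are the values it returns for the legacy flag combinations (they match the in-code comment):

    parse_dropout_structure([1, 2, 1], use_dropout=False, last_layer_dropout=False)    # [0, 0, 0]
    parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=False)  # [0, 0.3, 0, 0]
    parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=True)   # [0, 0.3, 0.3, 0]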
@@ -144,18 +162,22 @@ class Hypernetwork:
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
         self.activate_output = activate_output
-        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
+        self.dropout_structure = kwargs.get('dropout_structure', None)
+        if self.dropout_structure is None:
+            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
         self.optimizer_name = None
         self.optimizer_state_dict = None
+        self.optional_info = None

         for size in enable_sizes or []:
             self.layers[size] = (
                 HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                 HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
             )
-        self.eval_mode()
+        self.eval()

     def weights(self):
         res = []
@@ -164,14 +186,14 @@ class Hypernetwork:
                 res += layer.parameters()
         return res

-    def train_mode(self):
+    def train(self, mode=True):
         for k, layers in self.layers.items():
             for layer in layers:
-                layer.train()
+                layer.train(mode=mode)
                 for param in layer.parameters():
-                    param.requires_grad = True
+                    param.requires_grad = mode

-    def eval_mode(self):
+    def eval(self):
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.eval()
@@ -191,11 +213,13 @@ class Hypernetwork:
         state_dict['activation_func'] = self.activation_func
         state_dict['is_layer_norm'] = self.add_layer_norm
         state_dict['weight_initialization'] = self.weight_init
-        state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
-        state_dict['last_layer_dropout'] = self.last_layer_dropout
+        state_dict['use_dropout'] = self.use_dropout
+        state_dict['dropout_structure'] = self.dropout_structure
+        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
+        state_dict['optional_info'] = self.optional_info if self.optional_info else None

         if self.optimizer_name is not None:
             optimizer_saved_dict['optimizer_name'] = self.optimizer_name
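Note the compatibility shim when saving: last_layer_dropout is now derived from the structure rather than stored directly. For a (hypothetical) dropout_structure of [0, 0.3, 0.3, 0], dropout_structure[-2] is 0.3, so last_layer_dropout is saved as True and older builds can still read the checkpoint.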
@@ -215,43 +239,56 @@ class Hypernetwork:
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
         print(self.layer_structure)
+        optional_info = state_dict.get('optional_info', None)
+        if optional_info is not None:
+            print(f"INFO:\n {optional_info}\n")
+            self.optional_info = optional_info
         self.activation_func = state_dict.get('activation_func', None)
         print(f"Activation function is {self.activation_func}")
         self.weight_init = state_dict.get('weight_initialization', 'Normal')
         print(f"Weight initialization is {self.weight_init}")
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
         print(f"Layer norm is set to {self.add_layer_norm}")
-        self.use_dropout = state_dict.get('use_dropout', False)
+        self.dropout_structure = state_dict.get('dropout_structure', None)
+        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}" )
         self.activate_output = state_dict.get('activate_output', True)
         print(f"Activate last layer is set to {self.activate_output}")
         self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+        # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
+        if self.dropout_structure is None:
+            print("Using previous dropout structure")
+            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+        print(f"Dropout structure is set to {self.dropout_structure}")

         optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")

         if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
             self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
         else:
             self.optimizer_state_dict = None
         if self.optimizer_state_dict:
-            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
             print("Loaded existing optimizer from checkpoint")
-            print(f"Optimizer name is {self.optimizer_name}")
         else:
-            self.optimizer_name = "AdamW"
             print("No saved optimizer exists in checkpoint")

         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
                     HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                     HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                 )

         self.name = state_dict.get('name', self.name)
         self.step = state_dict.get('step', 0)
         self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+        self.eval()

 def list_hypernetworks(path):
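Note: the load path now handles two checkpoint generations; a hedged sketch with illustrative field values:

    # Legacy checkpoint: booleans only -> the structure is rebuilt from them.
    state_dict = {'layer_structure': [1, 2, 2, 1], 'use_dropout': True, 'last_layer_dropout': False}
    # 'dropout_structure' is absent -> parse_dropout_structure(...) -> [0, 0.3, 0, 0]

    # New checkpoint: an explicit structure wins, and use_dropout is inferred from it.
    state_dict = {'layer_structure': [1, 2, 2, 1], 'dropout_structure': [0, 0.05, 0.15, 0]}
    # any([0, 0.05, 0.15, 0]) is True -> use_dropout becomes True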
@@ -379,9 +416,10 @@ def report_statistics(loss_info:dict):
         print(e)

-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     # Remove illegal characters from name.
     name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+    assert name, "Name cannot be empty!"

     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     if not overwrite_old:
@@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     if type(layer_structure) == str:
         layer_structure = [float(x.strip()) for x in layer_structure.split(",")]

+    if use_dropout and dropout_structure and type(dropout_structure) == str:
+        dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
+    else:
+        dropout_structure = [0] * len(layer_structure)
+
     hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
         name=name,
         enable_sizes=[int(x) for x in enable_sizes],
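Note: for reference, how a UI string flows through this branch (illustrative input matching the textbox hint added in modules/ui.py below):

    dropout_structure = "0, 0.05, 0.15, 0"
    dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
    # -> [0.0, 0.05, 0.15, 0.0], one probability per layer_structure entry;
    # with the checkbox off or the field empty it falls back to [0] * len(layer_structure)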
@@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
         weight_init=weight_init,
         add_layer_norm=add_layer_norm,
         use_dropout=use_dropout,
+        dropout_structure=dropout_structure
     )

     hypernet.save(fn)
@@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
         shared.sd_model.first_stage_model.to(devices.cpu)

     weights = hypernetwork.weights()
-    hypernetwork.train_mode()
+    hypernetwork.train()

     # Here we use optimizer from saved HN, or we can specify as UI option.
     if hypernetwork.optimizer_name in optimizer_dict:
@@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                 if images_dir is not None and steps_done % create_image_every == 0:
                     forced_filename = f'{hypernetwork_name}-{steps_done}'
                     last_saved_image = os.path.join(images_dir, forced_filename)
-                    hypernetwork.eval_mode()
+                    hypernetwork.eval()
+                    rng_state = torch.get_rng_state()
+                    cuda_rng_state = None
+                    if torch.cuda.is_available():
+                        cuda_rng_state = torch.cuda.get_rng_state_all()

                     shared.sd_model.cond_stage_model.to(devices.device)
                     shared.sd_model.first_stage_model.to(devices.device)
@@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                     if unload:
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                         shared.sd_model.first_stage_model.to(devices.cpu)
-                    hypernetwork.train_mode()
+                    torch.set_rng_state(rng_state)
+                    if torch.cuda.is_available():
+                        torch.cuda.set_rng_state_all(cuda_rng_state)
+                    hypernetwork.train()

                     if image is not None:
                         shared.state.current_image = image
                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
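Note: the preview fix is the standard save/restore pattern for torch RNG state; as a minimal standalone sketch:

    import torch

    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

    torch.randn(4)  # preview generation would consume random numbers here

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state_all(cuda_rng_state)
    # training resumes with the exact RNG stream it had before the preview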
@@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     finally:
         pbar.leave = False
         pbar.close()
-        hypernetwork.eval_mode()
+        hypernetwork.eval()
        #report_statistics(loss_dict)

        filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

File: modules/hypernetworks/ui.py

@@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared

 not_available = ["hardswish", "multiheadattention"]
 keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)

-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
-    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

     return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""

File: modules/ui.py

@@ -1268,6 +1268,7 @@ def create_ui():
                         new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                         new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                         new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+                        new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
                         overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")

                 with gr.Row():
@@ -1414,7 +1415,8 @@ def create_ui():
                 new_hypernetwork_activation_func,
                 new_hypernetwork_initialization_option,
                 new_hypernetwork_add_layer_norm,
-                new_hypernetwork_use_dropout
+                new_hypernetwork_use_dropout,
+                new_hypernetwork_dropout_structure
             ],
             outputs=[
                 train_hypernetwork_name,