diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3ce85bb5..dd317085 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -28,7 +28,7 @@ class HypernetworkModule(torch.nn.Module): "swish": torch.nn.Hardswish, } - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -42,7 +42,7 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # Add an activation func except last layer - if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 2: + if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output): pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) @@ -105,7 +105,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False): self.filename = None self.name = name self.layers = {} @@ -116,11 +116,12 @@ class Hypernetwork: self.activation_func = activation_func self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout + self.activate_output = activate_output for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), ) def weights(self): @@ -147,6 +148,7 @@ class Hypernetwork: state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + state_dict['activate_output'] = self.activate_output torch.save(state_dict, filename) @@ -161,12 +163,13 @@ class Hypernetwork: self.activation_func = state_dict.get('activation_func', None) self.add_layer_norm = state_dict.get('is_layer_norm', False) self.use_dropout = state_dict.get('use_dropout', False) + self.activate_output = state_dict.get('activate_output', True) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output), ) self.name = state_dict.get('name', self.name)