fix for #3086 failing to load any previous hypernet
This commit is contained in:
parent
c664b231a8
commit
2ce52d32e4
@ -24,11 +24,10 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
|
|
||||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if layer_structure is not None:
|
|
||||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
assert layer_structure is not None, "layer_structure mut not be None"
|
||||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||||
else:
|
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||||
layer_structure = parse_layer_structure(dim, state_dict)
|
|
||||||
|
|
||||||
linears = []
|
linears = []
|
||||||
for i in range(len(layer_structure) - 1):
|
for i in range(len(layer_structure) - 1):
|
||||||
@ -39,23 +38,30 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
self.linear = torch.nn.Sequential(*linears)
|
self.linear = torch.nn.Sequential(*linears)
|
||||||
|
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
try:
|
self.fix_old_state_dict(state_dict)
|
||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict)
|
||||||
except RuntimeError:
|
|
||||||
self.try_load_previous(state_dict)
|
|
||||||
else:
|
else:
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
layer.weight.data.normal_(mean = 0.0, std = 0.01)
|
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||||
layer.bias.data.zero_()
|
layer.bias.data.zero_()
|
||||||
|
|
||||||
self.to(devices.device)
|
self.to(devices.device)
|
||||||
|
|
||||||
def try_load_previous(self, state_dict):
|
def fix_old_state_dict(self, state_dict):
|
||||||
states = self.state_dict()
|
changes = {
|
||||||
states['linear.0.bias'].copy_(state_dict['linear1.bias'])
|
'linear1.bias': 'linear.0.bias',
|
||||||
states['linear.0.weight'].copy_(state_dict['linear1.weight'])
|
'linear1.weight': 'linear.0.weight',
|
||||||
states['linear.1.bias'].copy_(state_dict['linear2.bias'])
|
'linear2.bias': 'linear.1.bias',
|
||||||
states['linear.1.weight'].copy_(state_dict['linear2.weight'])
|
'linear2.weight': 'linear.1.weight',
|
||||||
|
}
|
||||||
|
|
||||||
|
for fr, to in changes.items():
|
||||||
|
x = state_dict.get(fr, None)
|
||||||
|
if x is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
del state_dict[fr]
|
||||||
|
state_dict[to] = x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + self.linear(x) * self.multiplier
|
return x + self.linear(x) * self.multiplier
|
||||||
@ -71,18 +77,6 @@ def apply_strength(value=None):
|
|||||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||||
|
|
||||||
|
|
||||||
def parse_layer_structure(dim, state_dict):
|
|
||||||
i = 0
|
|
||||||
layer_structure = [1]
|
|
||||||
|
|
||||||
while (key := "linear.{}.weight".format(i)) in state_dict:
|
|
||||||
weight = state_dict[key]
|
|
||||||
layer_structure.append(len(weight) // dim)
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return layer_structure
|
|
||||||
|
|
||||||
|
|
||||||
class Hypernetwork:
|
class Hypernetwork:
|
||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
@ -135,17 +129,18 @@ class Hypernetwork:
|
|||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu')
|
state_dict = torch.load(filename, map_location='cpu')
|
||||||
|
|
||||||
|
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||||
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||||
|
|
||||||
for size, sd in state_dict.items():
|
for size, sd in state_dict.items():
|
||||||
if type(size) == int:
|
if type(size) == int:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
|
||||||
HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.name = state_dict.get('name', self.name)
|
self.name = state_dict.get('name', self.name)
|
||||||
self.step = state_dict.get('step', 0)
|
self.step = state_dict.get('step', 0)
|
||||||
self.layer_structure = state_dict.get('layer_structure', None)
|
|
||||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
|
||||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||||
|
|
||||||
@ -244,6 +239,7 @@ def stack_conds(conds):
|
|||||||
|
|
||||||
return torch.stack(conds)
|
return torch.stack(conds)
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert hypernetwork_name, 'hypernetwork not selected'
|
assert hypernetwork_name, 'hypernetwork not selected'
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user