Merge pull request #3086 from discus0434/master
Add settings for multi-layer structure hypernetworks
This commit is contained in:
commit
f510a2277e
@ -22,45 +22,86 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
|
||||
def __init__(self, dim, state_dict=None):
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
||||
super().__init__()
|
||||
if layer_structure is not None:
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
else:
|
||||
layer_structure = parse_layer_structure(dim, state_dict)
|
||||
|
||||
self.linear1 = torch.nn.Linear(dim, dim * 2)
|
||||
self.linear2 = torch.nn.Linear(dim * 2, dim)
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
if state_dict is not None:
|
||||
self.load_state_dict(state_dict, strict=True)
|
||||
try:
|
||||
self.load_state_dict(state_dict)
|
||||
except RuntimeError:
|
||||
self.try_load_previous(state_dict)
|
||||
else:
|
||||
|
||||
self.linear1.weight.data.normal_(mean=0.0, std=0.01)
|
||||
self.linear1.bias.data.zero_()
|
||||
self.linear2.weight.data.normal_(mean=0.0, std=0.01)
|
||||
self.linear2.bias.data.zero_()
|
||||
for layer in self.linear:
|
||||
layer.weight.data.normal_(mean = 0.0, std = 0.01)
|
||||
layer.bias.data.zero_()
|
||||
|
||||
self.to(devices.device)
|
||||
|
||||
def try_load_previous(self, state_dict):
|
||||
states = self.state_dict()
|
||||
states['linear.0.bias'].copy_(state_dict['linear1.bias'])
|
||||
states['linear.0.weight'].copy_(state_dict['linear1.weight'])
|
||||
states['linear.1.bias'].copy_(state_dict['linear2.bias'])
|
||||
states['linear.1.weight'].copy_(state_dict['linear2.weight'])
|
||||
|
||||
def forward(self, x):
|
||||
return x + (self.linear2(self.linear1(x))) * self.multiplier
|
||||
return x + self.linear(x) * self.multiplier
|
||||
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
for layer in self.linear:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
return layer_structure
|
||||
|
||||
|
||||
def apply_strength(value=None):
|
||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||
|
||||
|
||||
def parse_layer_structure(dim, state_dict):
|
||||
i = 0
|
||||
layer_structure = [1]
|
||||
|
||||
while (key := "linear.{}.weight".format(i)) in state_dict:
|
||||
weight = state_dict[key]
|
||||
layer_structure.append(len(weight) // dim)
|
||||
i += 1
|
||||
|
||||
return layer_structure
|
||||
|
||||
|
||||
class Hypernetwork:
|
||||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None):
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
self.step = 0
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.add_layer_norm = add_layer_norm
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
)
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
@ -68,7 +109,7 @@ class Hypernetwork:
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.train()
|
||||
res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
|
||||
res += layer.trainables()
|
||||
|
||||
return res
|
||||
|
||||
@ -80,6 +121,8 @@ class Hypernetwork:
|
||||
|
||||
state_dict['step'] = self.step
|
||||
state_dict['name'] = self.name
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
|
||||
@ -94,10 +137,15 @@ class Hypernetwork:
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
||||
HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
self.step = state_dict.get('step', 0)
|
||||
self.layer_structure = state_dict.get('layer_structure', None)
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||
|
||||
@ -226,7 +274,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
@ -261,7 +308,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
del x
|
||||
|
@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes):
|
||||
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
|
||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
|
||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
layer_structure=layer_structure,
|
||||
add_layer_norm=add_layer_norm,
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
@ -137,7 +137,7 @@ class State:
|
||||
self.job_no += 1
|
||||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
||||
|
||||
def get_job_timestamp(self):
|
||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||
|
||||
|
@ -477,14 +477,14 @@ def create_toprow(is_img2img):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
|
||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
|
||||
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
|
||||
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||
)
|
||||
|
||||
@ -1217,6 +1217,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
with gr.Tab(label="Create hypernetwork"):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)])
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
@ -1299,6 +1301,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
inputs=[
|
||||
new_hypernetwork_name,
|
||||
new_hypernetwork_sizes,
|
||||
new_hypernetwork_layer_structure,
|
||||
new_hypernetwork_add_layer_norm,
|
||||
],
|
||||
outputs=[
|
||||
train_hypernetwork_name,
|
||||
|
Loading…
Reference in New Issue
Block a user