layer options moves into create hnet ui
This commit is contained in:
parent
7f8670c4ef
commit
42fbda83bb
@ -19,37 +19,21 @@ from modules.textual_inversion import textual_inversion
|
|||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
|
|
||||||
def parse_layer_structure(dim, state_dict):
|
|
||||||
i = 0
|
|
||||||
res = [1]
|
|
||||||
while (key := "linear.{}.weight".format(i)) in state_dict:
|
|
||||||
weight = state_dict[key]
|
|
||||||
res.append(len(weight) // dim)
|
|
||||||
i += 1
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
layer_structure = None
|
|
||||||
add_layer_norm = False
|
|
||||||
|
|
||||||
def __init__(self, dim, state_dict=None):
|
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
|
if layer_structure is not None:
|
||||||
layer_structure = (1, 2, 1)
|
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||||
|
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||||
else:
|
else:
|
||||||
if self.layer_structure is not None:
|
layer_structure = parse_layer_structure(dim, state_dict)
|
||||||
assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
|
||||||
assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
|
||||||
layer_structure = self.layer_structure
|
|
||||||
else:
|
|
||||||
layer_structure = parse_layer_structure(dim, state_dict)
|
|
||||||
|
|
||||||
linears = []
|
linears = []
|
||||||
for i in range(len(layer_structure) - 1):
|
for i in range(len(layer_structure) - 1):
|
||||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||||
if self.add_layer_norm:
|
if add_layer_norm:
|
||||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
self.linear = torch.nn.Sequential(*linears)
|
self.linear = torch.nn.Sequential(*linears)
|
||||||
@ -77,38 +61,47 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
return x + self.linear(x) * self.multiplier
|
return x + self.linear(x) * self.multiplier
|
||||||
|
|
||||||
def trainables(self):
|
def trainables(self):
|
||||||
res = []
|
layer_structure = []
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
res += [layer.weight, layer.bias]
|
layer_structure += [layer.weight, layer.bias]
|
||||||
return res
|
return layer_structure
|
||||||
|
|
||||||
|
|
||||||
def apply_strength(value=None):
|
def apply_strength(value=None):
|
||||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||||
|
|
||||||
|
|
||||||
def apply_layer_structure(value=None):
|
def parse_layer_structure(dim, state_dict):
|
||||||
HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
|
i = 0
|
||||||
|
layer_structure = [1]
|
||||||
|
|
||||||
|
while (key := "linear.{}.weight".format(i)) in state_dict:
|
||||||
|
weight = state_dict[key]
|
||||||
|
layer_structure.append(len(weight) // dim)
|
||||||
|
i += 1
|
||||||
|
|
||||||
def apply_layer_norm(value=None):
|
return layer_structure
|
||||||
HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
|
|
||||||
|
|
||||||
|
|
||||||
class Hypernetwork:
|
class Hypernetwork:
|
||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
|
|
||||||
def __init__(self, name=None, enable_sizes=None):
|
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.layers = {}
|
self.layers = {}
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.sd_checkpoint = None
|
self.sd_checkpoint = None
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
|
self.layer_structure = layer_structure
|
||||||
|
self.add_layer_norm = add_layer_norm
|
||||||
|
|
||||||
for size in enable_sizes or []:
|
for size in enable_sizes or []:
|
||||||
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
|
self.layers[size] = (
|
||||||
|
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||||
|
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||||
|
)
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
res = []
|
res = []
|
||||||
@ -128,6 +121,8 @@ class Hypernetwork:
|
|||||||
|
|
||||||
state_dict['step'] = self.step
|
state_dict['step'] = self.step
|
||||||
state_dict['name'] = self.name
|
state_dict['name'] = self.name
|
||||||
|
state_dict['layer_structure'] = self.layer_structure
|
||||||
|
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||||
|
|
||||||
@ -142,10 +137,15 @@ class Hypernetwork:
|
|||||||
|
|
||||||
for size, sd in state_dict.items():
|
for size, sd in state_dict.items():
|
||||||
if type(size) == int:
|
if type(size) == int:
|
||||||
self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
|
self.layers[size] = (
|
||||||
|
HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
||||||
|
HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
||||||
|
)
|
||||||
|
|
||||||
self.name = state_dict.get('name', self.name)
|
self.name = state_dict.get('name', self.name)
|
||||||
self.step = state_dict.get('step', 0)
|
self.step = state_dict.get('step', 0)
|
||||||
|
self.layer_structure = state_dict.get('layer_structure', None)
|
||||||
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||||
|
|
||||||
|
@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices
|
|||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes):
|
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
|
||||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
|
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||||
|
name=name,
|
||||||
|
enable_sizes=[int(x) for x in enable_sizes],
|
||||||
|
layer_structure=layer_structure,
|
||||||
|
add_layer_norm=add_layer_norm,
|
||||||
|
)
|
||||||
hypernet.save(fn)
|
hypernet.save(fn)
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
@ -260,8 +260,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
"sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
|
|
||||||
"sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
|
|
||||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
|
@ -458,14 +458,14 @@ def create_toprow(is_img2img):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
|
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
|
||||||
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
|
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
|
||||||
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1198,6 +1198,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
with gr.Tab(label="Create hypernetwork"):
|
with gr.Tab(label="Create hypernetwork"):
|
||||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||||
|
new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)])
|
||||||
|
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
@ -1280,6 +1282,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
inputs=[
|
inputs=[
|
||||||
new_hypernetwork_name,
|
new_hypernetwork_name,
|
||||||
new_hypernetwork_sizes,
|
new_hypernetwork_sizes,
|
||||||
|
new_hypernetwork_layer_structure,
|
||||||
|
new_hypernetwork_add_layer_norm,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
|
8
webui.py
8
webui.py
@ -85,9 +85,7 @@ def initialize():
|
|||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||||
shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure)
|
|
||||||
shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm)
|
|
||||||
|
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
def sigint_handler(sig, frame):
|
def sigint_handler(sig, frame):
|
||||||
print(f'Interrupted with signal {sig} in {frame}')
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
@ -142,7 +140,7 @@ def webui(launch_api=False):
|
|||||||
create_api(app)
|
create_api(app)
|
||||||
|
|
||||||
wait_on_server(demo)
|
wait_on_server(demo)
|
||||||
|
|
||||||
sd_samplers.set_samplers()
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
print('Reloading Custom Scripts')
|
print('Reloading Custom Scripts')
|
||||||
@ -160,4 +158,4 @@ if __name__ == "__main__":
|
|||||||
if cmd_opts.nowebui:
|
if cmd_opts.nowebui:
|
||||||
api_only()
|
api_only()
|
||||||
else:
|
else:
|
||||||
webui(cmd_opts.api)
|
webui(cmd_opts.api)
|
||||||
|
Loading…
Reference in New Issue
Block a user