add option to discard weights in checkpoint merger UI
This commit is contained in:
parent
0792fae078
commit
112416d041
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import shutil
|
import shutil
|
||||||
@ -285,7 +286,7 @@ def to_half(tensor, enable):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
@ -430,6 +431,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
for key in theta_0.keys():
|
for key in theta_0.keys():
|
||||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||||
|
|
||||||
|
if discard_weights:
|
||||||
|
regex = re.compile(discard_weights)
|
||||||
|
for key in list(theta_0):
|
||||||
|
if re.search(regex, key):
|
||||||
|
theta_0.pop(key, None)
|
||||||
|
|
||||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||||
|
|
||||||
filename = filename_generator() if custom_name == '' else custom_name
|
filename = filename_generator() if custom_name == '' else custom_name
|
||||||
|
@ -1248,6 +1248,9 @@ def create_ui():
|
|||||||
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
|
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
|
||||||
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
|
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
|
discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||||
|
|
||||||
@ -1838,6 +1841,7 @@ def create_ui():
|
|||||||
checkpoint_format,
|
checkpoint_format,
|
||||||
config_source,
|
config_source,
|
||||||
bake_in_vae,
|
bake_in_vae,
|
||||||
|
discard_weights,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
primary_model_name,
|
primary_model_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user