diff --git a/modules/extras.py b/modules/extras.py index 7407bfe3..a03d558e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -3,6 +3,7 @@ import math import os import sys import traceback +import shutil import numpy as np from PIL import Image @@ -248,7 +249,32 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + return sd_models.find_checkpoint_config(x) if x else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): shared.state.begin() shared.state.job = 'model-merge' @@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam sd_models.list_models() + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + print("Checkpoint saved.") shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.end() diff --git a/modules/ui.py b/modules/ui.py index 3c458ce8..82f5dd7c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1129,7 +1129,7 @@ def create_ui(): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - with gr.Row(): + with FormRow(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") @@ -1143,11 +1143,13 @@ def create_ui(): interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - with gr.Row(): + with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) @@ -1703,6 +1705,7 @@ def create_ui(): save_as_half, custom_name, checkpoint_format, + config_source, ], outputs=[ submit_result,