init job and add info to model merge
This commit is contained in:
parent
e9fb9bb0c2
commit
1d9dc48efd
@ -242,6 +242,9 @@ def run_pnginfo(image):
|
|||||||
|
|
||||||
|
|
||||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||||
|
shared.state.begin()
|
||||||
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
def weighted_sum(theta0, theta1, alpha):
|
def weighted_sum(theta0, theta1, alpha):
|
||||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||||
|
|
||||||
@ -263,8 +266,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||||
|
|
||||||
if theta_func1 and not tertiary_model_info:
|
if theta_func1 and not tertiary_model_info:
|
||||||
|
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
|
||||||
|
shared.state.end()
|
||||||
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
@ -281,6 +287,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||||
del theta_2
|
del theta_2
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||||
print(f"Loading {primary_model_info.filename}...")
|
print(f"Loading {primary_model_info.filename}...")
|
||||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
@ -291,6 +298,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||||||
a = theta_0[key]
|
a = theta_0[key]
|
||||||
b = theta_1[key]
|
b = theta_1[key]
|
||||||
|
|
||||||
|
shared.state.textinfo = f'Merging layer {key}'
|
||||||
# this enables merging an inpainting model (A) with another one (B);
|
# this enables merging an inpainting model (A) with another one (B);
|
||||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||||
@ -303,8 +311,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||||
result_is_inpainting_model = True
|
result_is_inpainting_model = True
|
||||||
else:
|
else:
|
||||||
assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
|
|
||||||
|
|
||||||
theta_0[key] = theta_func2(a, b, multiplier)
|
theta_0[key] = theta_func2(a, b, multiplier)
|
||||||
|
|
||||||
if save_as_half:
|
if save_as_half:
|
||||||
@ -332,6 +338,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||||||
|
|
||||||
output_modelname = os.path.join(ckpt_dir, filename)
|
output_modelname = os.path.join(ckpt_dir, filename)
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
@ -343,4 +350,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
|||||||
sd_models.list_models()
|
sd_models.list_models()
|
||||||
|
|
||||||
print("Checkpoint saved.")
|
print("Checkpoint saved.")
|
||||||
|
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||||
|
Loading…
Reference in New Issue
Block a user