Embed model merge metadata in .safetensors file
This commit is contained in:
parent
22bcc7be42
commit
d132481058
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -71,7 +72,7 @@ def to_half(tensor, enable):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
@ -241,13 +242,52 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
shared.state.textinfo = "Saving"
|
shared.state.textinfo = "Saving"
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
|
metadata = {"format": "pt", "models": {}, "merge_recipe": None}
|
||||||
|
|
||||||
|
if save_metadata:
|
||||||
|
merge_recipe = {
|
||||||
|
"primary_model_hash": primary_model_info.sha256,
|
||||||
|
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
||||||
|
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
||||||
|
"interp_method": interp_method,
|
||||||
|
"multiplier": multiplier,
|
||||||
|
"save_as_half": save_as_half,
|
||||||
|
"custom_name": custom_name,
|
||||||
|
"config_source": config_source,
|
||||||
|
"bake_in_vae": bake_in_vae,
|
||||||
|
"discard_weights": discard_weights,
|
||||||
|
"is_inpainting": result_is_inpainting_model,
|
||||||
|
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||||
|
}
|
||||||
|
metadata["merge_recipe"] = json.dumps(merge_recipe)
|
||||||
|
|
||||||
|
def add_model_metadata(checkpoint_info):
|
||||||
|
metadata["models"][checkpoint_info.sha256] = {
|
||||||
|
"name": checkpoint_info.name,
|
||||||
|
"legacy_hash": checkpoint_info.hash,
|
||||||
|
"merge_recipe": checkpoint_info.metadata.get("merge_recipe", None)
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata["models"].update(checkpoint_info.metadata.get("models", {}))
|
||||||
|
|
||||||
|
add_model_metadata(primary_model_info)
|
||||||
|
if secondary_model_info:
|
||||||
|
add_model_metadata(secondary_model_info)
|
||||||
|
if tertiary_model_info:
|
||||||
|
add_model_metadata(tertiary_model_info)
|
||||||
|
|
||||||
|
metadata["models"] = json.dumps(metadata["models"])
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
|
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(theta_0, output_modelname)
|
torch.save(theta_0, output_modelname)
|
||||||
|
|
||||||
sd_models.list_models()
|
sd_models.list_models()
|
||||||
|
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
||||||
|
if created_model:
|
||||||
|
created_model.calculate_shorthash()
|
||||||
|
|
||||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||||
|
|
||||||
|
@ -52,6 +52,15 @@ class CheckpointInfo:
|
|||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||||
|
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
_, ext = os.path.splitext(self.filename)
|
||||||
|
if ext.lower() == ".safetensors":
|
||||||
|
try:
|
||||||
|
self.metadata = read_metadata_from_safetensors(filename)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading checkpoint metadata: {filename}")
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
for id in self.ids:
|
for id in self.ids:
|
||||||
|
@ -1019,8 +1019,9 @@ def create_ui():
|
|||||||
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
|
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||||
|
save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -1658,6 +1659,7 @@ def create_ui():
|
|||||||
config_source,
|
config_source,
|
||||||
bake_in_vae,
|
bake_in_vae,
|
||||||
discard_weights,
|
discard_weights,
|
||||||
|
save_metadata,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
primary_model_name,
|
primary_model_name,
|
||||||
|
Loading…
Reference in New Issue
Block a user