2022-09-19 22:13:12 +00:00
import os
2023-01-22 07:17:12 +00:00
import re
2023-01-11 06:10:07 +00:00
import shutil
2022-09-19 22:13:12 +00:00
2022-09-11 08:31:16 +00:00
2022-09-25 23:22:12 +00:00
import torch
2022-09-27 07:44:00 +00:00
import tqdm
2022-09-25 23:22:12 +00:00
2023-01-22 12:38:39 +00:00
from modules import shared , images , sd_models , sd_vae
2023-01-23 11:50:20 +00:00
from modules . ui_common import plaintext_to_html
2022-09-28 21:59:44 +00:00
import gradio as gr
2022-11-27 12:51:29 +00:00
import safetensors . torch
2022-09-13 16:23:55 +00:00
2022-09-11 08:31:16 +00:00
2022-09-17 06:07:07 +00:00
def run_pnginfo(image):
    """Read the generation info and other metadata embedded in an image.

    Args:
        image: the image to inspect, or None.

    Returns:
        ``('', geninfo, info_html)`` where ``geninfo`` is the generation
        parameters text returned by ``images.read_info_from_image`` and
        ``info_html`` is an HTML fragment with one ``<div>`` per metadata entry
        that has a value (or a "Nothing found in the image." notice when no
        entry has one). Returns ``('', '', '')`` when ``image`` is None.
    """
    if image is None:
        return '', '', ''

    geninfo, items = images.read_info_from_image(image)
    # List the generation parameters first; a 'parameters' entry in the image
    # metadata itself overrides it.
    items = {'parameters': geninfo, **items}

    sections = []
    for key, text in items.items():
        # Skip entries without a value (e.g. an image with no generation
        # parameters) instead of rendering the text "None"; this is also what
        # makes the "Nothing found" fallback below reachable.
        if text is None:
            continue
        sections.append(
            "<div>\n"
            f"<p><b>{plaintext_to_html(str(key))}</b></p>\n"
            f"<p>{plaintext_to_html(str(text))}</p>\n"
            "</div>\n"
        )

    info = "".join(sections)
    if not info:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"

    return '', geninfo, info
2022-09-25 23:22:12 +00:00
2023-01-11 06:10:07 +00:00
def create_config(ckpt_result, config_source, a, b, c):
    """Copy a non-default model config next to a freshly merged checkpoint.

    Args:
        ckpt_result: path of the merged checkpoint; the config is copied to the
            same path with its extension replaced by ``.yaml``.
        config_source: 0 = first of a, b, c that has a non-default config;
            1 = b's config; 2 = c's config; any other value = copy nothing.
        a, b, c: checkpoint-list entries of the primary, secondary and
            tertiary models, or None when a model is not used.

    Nothing is copied when the selected model(s) only have the default config.
    """
    def config(x):
        # Config path for checkpoint x, or None if x is missing or only has
        # the default config.
        found = sd_models.find_checkpoint_config(x) if x else None
        if found == shared.sd_default_config:
            return None
        return found

    if config_source == 0:
        candidates = (a, b, c)
    elif config_source == 1:
        candidates = (b,)
    elif config_source == 2:
        candidates = (c,)
    else:
        candidates = ()

    # Take the first truthy config, falling back to the last one looked up.
    cfg = None
    for candidate in candidates:
        cfg = config(candidate)
        if cfg:
            break

    if cfg is None:
        return

    stem = os.path.splitext(ckpt_result)[0]
    checkpoint_filename = stem + ".yaml"

    print("Copying config:")
    print("from:", cfg)
    print("to:", checkpoint_filename)
    shutil.copyfile(cfg, checkpoint_filename)
2023-01-19 15:24:17 +00:00
# State-dict keys that the model merger never combines: run_modelmerger skips
# them in both merge loops, so the result keeps the primary model's tensor.
checkpoint_dict_skip_on_merge = [
    "cond_stage_model.transformer.text_model.embeddings.position_ids",
]
2023-01-19 07:39:51 +00:00
2023-01-19 09:12:09 +00:00
def to_half(tensor, enable):
    """Return ``tensor`` as float16 when ``enable`` is truthy and it is float32.

    Any other tensor (other dtype, or ``enable`` falsy) is returned unchanged,
    as the same object.
    """
    if not enable or tensor.dtype != torch.float:
        return tensor
    return tensor.half()
2023-01-22 07:17:12 +00:00
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
    """Merge up to three checkpoints (A, B, C) and save the result.

    ``interp_method`` selects how the models are combined:
      * "Weighted sum":     (1 - multiplier) * A + multiplier * B
      * "Add difference":   A + multiplier * (B - C)
      * "No interpolation": A unchanged (but still optionally halved,
                            VAE-baked and pruned)

    Args:
        id_task: task id; not used in this function body.
        primary_model_name, secondary_model_name, tertiary_model_name: keys of
            ``sd_models.checkpoints_list`` for A, B and C. B and C are only
            required by the methods that use them.
        interp_method: one of the method names above.
        multiplier: interpolation weight passed as ``alpha``.
        save_as_half: convert float32 tensors of the result to float16.
        custom_name: output file stem; '' derives a name from the inputs.
        checkpoint_format: output extension without the dot; "safetensors"
            saves with safetensors, anything else with ``torch.save``.
        config_source: forwarded to ``create_config`` to choose which model's
            config is copied next to the result.
        bake_in_vae: key of ``sd_vae.vae_dict``; when present, its weights
            replace the matching ``first_stage_model.*`` tensors of the result.
        discard_weights: regex; matching keys are removed from the result
            (empty keeps everything).

    Returns:
        Four Gradio updates (one per checkpoint dropdown) followed by a status
        message.

    Raises:
        KeyError: unknown ``interp_method`` or model name.
        RuntimeError: A is a normal model while B is an inpainting model.
    """
    shared.state.begin()
    shared.state.job = 'model-merge'

    def fail(message):
        shared.state.textinfo = message
        shared.state.end()
        return [*[gr.update() for _ in range(4)], message]

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def get_difference(theta1, theta2):
        return theta1 - theta2

    def add_difference(theta0, theta1_2_diff, alpha):
        return theta0 + (alpha * theta1_2_diff)

    def filename_weighted_sum():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        Ma = round(1 - multiplier, 2)
        Mb = round(multiplier, 2)

        return f"{Ma}({a}) + {Mb}({b})"

    def filename_add_difference():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        c = tertiary_model_info.model_name
        M = round(multiplier, 2)

        return f"{a} + {M}({b} - {c})"

    def filename_nothing():
        return primary_model_info.model_name

    # interp_method -> (output-name generator, B/C combiner, A/B combiner).
    # theta_func1 is set only when model C is needed, theta_func2 only when B is.
    theta_funcs = {
        "Weighted sum": (filename_weighted_sum, None, weighted_sum),
        "Add difference": (filename_add_difference, get_difference, add_difference),
        "No interpolation": (filename_nothing, None, None),
    }
    filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
    shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)

    if not primary_model_name:
        return fail("Failed: Merging requires a primary model.")

    primary_model_info = sd_models.checkpoints_list[primary_model_name]

    if theta_func2 and not secondary_model_name:
        return fail("Failed: Merging requires a secondary model.")

    secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None

    if theta_func1 and not tertiary_model_name:
        return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")

    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None

    result_is_inpainting_model = False
    result_is_pix2pix_model = False

    if theta_func2:
        shared.state.textinfo = "Loading B"
        print(f"Loading {secondary_model_info.filename}...")
        theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
    else:
        theta_1 = None

    if theta_func1:
        shared.state.textinfo = "Loading C"
        print(f"Loading {tertiary_model_info.filename}...")
        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

        shared.state.textinfo = 'Merging B and C'
        shared.state.sampling_steps = len(theta_1.keys())
        for key in tqdm.tqdm(theta_1.keys()):
            if key not in checkpoint_dict_skip_on_merge and 'model' in key:
                if key in theta_2:
                    # Index directly: a .get() default would allocate a full
                    # zero tensor on every call even though the key exists.
                    theta_1[key] = theta_func1(theta_1[key], theta_2[key])
                else:
                    # C lacks this tensor: a zero difference leaves A unchanged.
                    theta_1[key] = torch.zeros_like(theta_1[key])

            # Advance progress for every key, including skipped ones.
            shared.state.sampling_step += 1
        del theta_2

        shared.state.nextjob()

    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')

    print("Merging...")
    shared.state.textinfo = 'Merging A and B'
    shared.state.sampling_steps = len(theta_0.keys())
    for key in tqdm.tqdm(theta_0.keys()):
        if theta_1 and 'model' in key and key in theta_1 and key not in checkpoint_dict_skip_on_merge:
            a = theta_0[key]
            b = theta_1[key]

            # This enables merging an inpainting model (A) with another one (B):
            # where a normal model has 4 channels for latent space, an inpainting
            # model has another 4 channels for the unmasked picture's latent space,
            # plus one channel for the mask, for a total of 9.
            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
                if a.shape[1] == 8 and b.shape[1] == 4:  # If we have an InstructPix2Pix model...
                    print("Detected possible merge of instruct model with non-instruct model.")
                    # Merge only the channels both models have in common; otherwise
                    # the shapes mismatch.
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_pix2pix_model = True
                else:
                    assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)

            theta_0[key] = to_half(theta_0[key], save_as_half)

        # Advance progress for every key, including skipped ones.
        shared.state.sampling_step += 1

    del theta_1

    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"Baking in VAE from {bake_in_vae_filename}")
        shared.state.textinfo = 'Baking in VAE'
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)

        del vae_dict

    # Halve every remaining float32 tensor. The merge loop above only converts
    # keys it merged; keys that exist only in A would otherwise stay float32.
    # to_half leaves already-converted tensors untouched.
    if save_as_half:
        for key in theta_0.keys():
            theta_0[key] = to_half(theta_0[key], save_as_half)

    if discard_weights:
        regex = re.compile(discard_weights)
        for key in list(theta_0):
            if regex.search(key):
                theta_0.pop(key, None)

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = filename_generator() if custom_name == '' else custom_name
    filename += ".inpainting" if result_is_inpainting_model else ""
    filename += ".pix2pix" if result_is_pix2pix_model else ""
    filename += "." + checkpoint_format

    output_modelname = os.path.join(ckpt_dir, filename)

    shared.state.nextjob()
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
    else:
        torch.save(theta_0, output_modelname)

    sd_models.list_models()

    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)

    print(f"Checkpoint saved to {output_modelname}.")
    shared.state.textinfo = "Checkpoint saved"
    shared.state.end()

    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]