v18: Save model as option added
This commit is contained in:
parent
fc22813b8f
commit
f459c32a3e
@ -129,6 +129,11 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
|
|||||||
- Lord of the universe - cacoe (twitter: @cac0e)
|
- Lord of the universe - cacoe (twitter: @cac0e)
|
||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 12/17 (v18) update:
|
||||||
|
- Save model as option added to train_db_fixed.py
|
||||||
|
- Save model as option added to GUI
|
||||||
|
- Retire "Model conversion" parameters that was essentially performing the same function as the new `--save_model_as` parameter
|
||||||
* 12/17 (v17.2) update:
|
* 12/17 (v17.2) update:
|
||||||
- Adding new dataset balancing utility.
|
- Adding new dataset balancing utility.
|
||||||
* 12/17 (v17.1) update:
|
* 12/17 (v17.1) update:
|
||||||
|
@ -47,11 +47,8 @@ def save_configuration(
|
|||||||
save_precision,
|
save_precision,
|
||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
convert_to_safetensors,
|
|
||||||
convert_to_ckpt,
|
|
||||||
cache_latent,
|
cache_latent,
|
||||||
caption_extention,
|
caption_extention,
|
||||||
use_safetensors,
|
|
||||||
enable_bucket,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16,
|
full_fp16,
|
||||||
@ -59,6 +56,7 @@ def save_configuration(
|
|||||||
stop_text_encoder_training,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
|
save_model_as
|
||||||
):
|
):
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
|
|
||||||
@ -103,11 +101,8 @@ def save_configuration(
|
|||||||
'save_precision': save_precision,
|
'save_precision': save_precision,
|
||||||
'seed': seed,
|
'seed': seed,
|
||||||
'num_cpu_threads_per_process': num_cpu_threads_per_process,
|
'num_cpu_threads_per_process': num_cpu_threads_per_process,
|
||||||
'convert_to_safetensors': convert_to_safetensors,
|
|
||||||
'convert_to_ckpt': convert_to_ckpt,
|
|
||||||
'cache_latent': cache_latent,
|
'cache_latent': cache_latent,
|
||||||
'caption_extention': caption_extention,
|
'caption_extention': caption_extention,
|
||||||
'use_safetensors': use_safetensors,
|
|
||||||
'enable_bucket': enable_bucket,
|
'enable_bucket': enable_bucket,
|
||||||
'gradient_checkpointing': gradient_checkpointing,
|
'gradient_checkpointing': gradient_checkpointing,
|
||||||
'full_fp16': full_fp16,
|
'full_fp16': full_fp16,
|
||||||
@ -115,6 +110,7 @@ def save_configuration(
|
|||||||
'stop_text_encoder_training': stop_text_encoder_training,
|
'stop_text_encoder_training': stop_text_encoder_training,
|
||||||
'use_8bit_adam': use_8bit_adam,
|
'use_8bit_adam': use_8bit_adam,
|
||||||
'xformers': xformers,
|
'xformers': xformers,
|
||||||
|
'save_model_as': save_model_as
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the data to the selected file
|
# Save the data to the selected file
|
||||||
@ -144,11 +140,8 @@ def open_configuration(
|
|||||||
save_precision,
|
save_precision,
|
||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
convert_to_safetensors,
|
|
||||||
convert_to_ckpt,
|
|
||||||
cache_latent,
|
cache_latent,
|
||||||
caption_extention,
|
caption_extention,
|
||||||
use_safetensors,
|
|
||||||
enable_bucket,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16,
|
full_fp16,
|
||||||
@ -156,6 +149,7 @@ def open_configuration(
|
|||||||
stop_text_encoder_training,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
|
save_model_as
|
||||||
):
|
):
|
||||||
|
|
||||||
original_file_path = file_path
|
original_file_path = file_path
|
||||||
@ -195,11 +189,8 @@ def open_configuration(
|
|||||||
my_data.get(
|
my_data.get(
|
||||||
'num_cpu_threads_per_process', num_cpu_threads_per_process
|
'num_cpu_threads_per_process', num_cpu_threads_per_process
|
||||||
),
|
),
|
||||||
my_data.get('convert_to_safetensors', convert_to_safetensors),
|
|
||||||
my_data.get('convert_to_ckpt', convert_to_ckpt),
|
|
||||||
my_data.get('cache_latent', cache_latent),
|
my_data.get('cache_latent', cache_latent),
|
||||||
my_data.get('caption_extention', caption_extention),
|
my_data.get('caption_extention', caption_extention),
|
||||||
my_data.get('use_safetensors', use_safetensors),
|
|
||||||
my_data.get('enable_bucket', enable_bucket),
|
my_data.get('enable_bucket', enable_bucket),
|
||||||
my_data.get('gradient_checkpointing', gradient_checkpointing),
|
my_data.get('gradient_checkpointing', gradient_checkpointing),
|
||||||
my_data.get('full_fp16', full_fp16),
|
my_data.get('full_fp16', full_fp16),
|
||||||
@ -207,6 +198,7 @@ def open_configuration(
|
|||||||
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
|
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
|
||||||
my_data.get('use_8bit_adam', use_8bit_adam),
|
my_data.get('use_8bit_adam', use_8bit_adam),
|
||||||
my_data.get('xformers', xformers),
|
my_data.get('xformers', xformers),
|
||||||
|
my_data.get('save_model_as', save_model_as)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -229,11 +221,8 @@ def train_model(
|
|||||||
save_precision,
|
save_precision,
|
||||||
seed,
|
seed,
|
||||||
num_cpu_threads_per_process,
|
num_cpu_threads_per_process,
|
||||||
convert_to_safetensors,
|
|
||||||
convert_to_ckpt,
|
|
||||||
cache_latent,
|
cache_latent,
|
||||||
caption_extention,
|
caption_extention,
|
||||||
use_safetensors,
|
|
||||||
enable_bucket,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16,
|
full_fp16,
|
||||||
@ -241,6 +230,7 @@ def train_model(
|
|||||||
stop_text_encoder_training_pct,
|
stop_text_encoder_training_pct,
|
||||||
use_8bit_adam,
|
use_8bit_adam,
|
||||||
xformers,
|
xformers,
|
||||||
|
save_model_as
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
def save_inference_file(output_dir, v2, v_parameterization):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -352,8 +342,6 @@ def train_model(
|
|||||||
run_cmd += ' --v_parameterization'
|
run_cmd += ' --v_parameterization'
|
||||||
if cache_latent:
|
if cache_latent:
|
||||||
run_cmd += ' --cache_latents'
|
run_cmd += ' --cache_latents'
|
||||||
if use_safetensors:
|
|
||||||
run_cmd += ' --use_safetensors'
|
|
||||||
if enable_bucket:
|
if enable_bucket:
|
||||||
run_cmd += ' --enable_bucket'
|
run_cmd += ' --enable_bucket'
|
||||||
if gradient_checkpointing:
|
if gradient_checkpointing:
|
||||||
@ -388,39 +376,20 @@ def train_model(
|
|||||||
run_cmd += f' --logging_dir={logging_dir}'
|
run_cmd += f' --logging_dir={logging_dir}'
|
||||||
run_cmd += f' --caption_extention={caption_extention}'
|
run_cmd += f' --caption_extention={caption_extention}'
|
||||||
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
|
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
|
||||||
|
if not save_model_as == 'same as source model':
|
||||||
|
run_cmd += f' --save_model_as={save_model_as}'
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
subprocess.run(run_cmd)
|
subprocess.run(run_cmd)
|
||||||
|
|
||||||
# check if output_dir/last is a directory... therefore it is a diffuser model
|
# check if output_dir/last is a folder... therefore it is a diffuser model
|
||||||
last_dir = pathlib.Path(f'{output_dir}/last')
|
last_dir = pathlib.Path(f'{output_dir}/last')
|
||||||
print(last_dir)
|
|
||||||
if last_dir.is_dir():
|
|
||||||
if convert_to_ckpt:
|
|
||||||
print(f'Converting diffuser model {last_dir} to {last_dir}.ckpt')
|
|
||||||
os.system(
|
|
||||||
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.ckpt --{save_precision}'
|
|
||||||
)
|
|
||||||
|
|
||||||
save_inference_file(output_dir, v2, v_parameterization)
|
if not last_dir.is_dir():
|
||||||
|
|
||||||
if convert_to_safetensors:
|
|
||||||
print(
|
|
||||||
f'Converting diffuser model {last_dir} to {last_dir}.safetensors'
|
|
||||||
)
|
|
||||||
os.system(
|
|
||||||
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}'
|
|
||||||
)
|
|
||||||
|
|
||||||
save_inference_file(output_dir, v2, v_parameterization)
|
|
||||||
else:
|
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
save_inference_file(output_dir, v2, v_parameterization)
|
save_inference_file(output_dir, v2, v_parameterization)
|
||||||
|
|
||||||
# Return the values of the variables as a dictionary
|
|
||||||
# return
|
|
||||||
|
|
||||||
|
|
||||||
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
||||||
# define a list of substrings to search for
|
# define a list of substrings to search for
|
||||||
@ -533,6 +502,17 @@ with interface:
|
|||||||
'CompVis/stable-diffusion-v1-4',
|
'CompVis/stable-diffusion-v1-4',
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
save_model_as_dropdown = gr.Dropdown(
|
||||||
|
label='Save trained model as',
|
||||||
|
choices=[
|
||||||
|
'same as source model',
|
||||||
|
'ckpt',
|
||||||
|
'diffusers',
|
||||||
|
"diffusers_safetensors",
|
||||||
|
'safetensors',
|
||||||
|
],
|
||||||
|
value='same as source model'
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v2_input = gr.Checkbox(label='v2', value=True)
|
v2_input = gr.Checkbox(label='v2', value=True)
|
||||||
v_parameterization_input = gr.Checkbox(
|
v_parameterization_input = gr.Checkbox(
|
||||||
@ -557,7 +537,7 @@ with interface:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_data_dir_input = gr.Textbox(
|
train_data_dir_input = gr.Textbox(
|
||||||
label='Image folder',
|
label='Image folder',
|
||||||
placeholder='Directory where the training folders containing the images are located',
|
placeholder='Folder where the training folders containing the images are located',
|
||||||
)
|
)
|
||||||
train_data_dir_input_folder = gr.Button(
|
train_data_dir_input_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
@ -567,7 +547,7 @@ with interface:
|
|||||||
)
|
)
|
||||||
reg_data_dir_input = gr.Textbox(
|
reg_data_dir_input = gr.Textbox(
|
||||||
label='Regularisation folder',
|
label='Regularisation folder',
|
||||||
placeholder='(Optional) Directory where where the regularization folders containing the images are located',
|
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||||
)
|
)
|
||||||
reg_data_dir_input_folder = gr.Button(
|
reg_data_dir_input_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
@ -577,8 +557,8 @@ with interface:
|
|||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_dir_input = gr.Textbox(
|
output_dir_input = gr.Textbox(
|
||||||
label='Output directory',
|
label='Output folder',
|
||||||
placeholder='Directory to output trained model',
|
placeholder='Folder to output trained model',
|
||||||
)
|
)
|
||||||
output_dir_input_folder = gr.Button(
|
output_dir_input_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
@ -587,8 +567,8 @@ with interface:
|
|||||||
get_folder_path, outputs=output_dir_input
|
get_folder_path, outputs=output_dir_input
|
||||||
)
|
)
|
||||||
logging_dir_input = gr.Textbox(
|
logging_dir_input = gr.Textbox(
|
||||||
label='Logging directory',
|
label='Logging folder',
|
||||||
placeholder='Optional: enable logging and output TensorBoard log to this directory',
|
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||||
)
|
)
|
||||||
logging_dir_input_folder = gr.Button(
|
logging_dir_input_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
@ -694,9 +674,6 @@ with interface:
|
|||||||
no_token_padding_input = gr.Checkbox(
|
no_token_padding_input = gr.Checkbox(
|
||||||
label='No token padding', value=False
|
label='No token padding', value=False
|
||||||
)
|
)
|
||||||
use_safetensors_input = gr.Checkbox(
|
|
||||||
label='Use safetensor when saving', value=False
|
|
||||||
)
|
|
||||||
|
|
||||||
gradient_checkpointing_input = gr.Checkbox(
|
gradient_checkpointing_input = gr.Checkbox(
|
||||||
label='Gradient checkpointing', value=False
|
label='Gradient checkpointing', value=False
|
||||||
@ -711,13 +688,6 @@ with interface:
|
|||||||
)
|
)
|
||||||
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
||||||
|
|
||||||
with gr.Tab('Model conversion'):
|
|
||||||
convert_to_safetensors_input = gr.Checkbox(
|
|
||||||
label='Convert to SafeTensors', value=True
|
|
||||||
)
|
|
||||||
convert_to_ckpt_input = gr.Checkbox(
|
|
||||||
label='Convert to CKPT', value=False
|
|
||||||
)
|
|
||||||
with gr.Tab('Utilities'):
|
with gr.Tab('Utilities'):
|
||||||
# Dreambooth folder creation tab
|
# Dreambooth folder creation tab
|
||||||
gradio_dreambooth_folder_creation_tab(
|
gradio_dreambooth_folder_creation_tab(
|
||||||
@ -729,6 +699,13 @@ with interface:
|
|||||||
# Captionning tab
|
# Captionning tab
|
||||||
gradio_caption_gui_tab()
|
gradio_caption_gui_tab()
|
||||||
gradio_dataset_balancing_tab()
|
gradio_dataset_balancing_tab()
|
||||||
|
# with gr.Tab('Model conversion'):
|
||||||
|
# convert_to_safetensors_input = gr.Checkbox(
|
||||||
|
# label='Convert to SafeTensors', value=True
|
||||||
|
# )
|
||||||
|
# convert_to_ckpt_input = gr.Checkbox(
|
||||||
|
# label='Convert to CKPT', value=False
|
||||||
|
# )
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
@ -754,11 +731,8 @@ with interface:
|
|||||||
save_precision_input,
|
save_precision_input,
|
||||||
seed_input,
|
seed_input,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process_input,
|
||||||
convert_to_safetensors_input,
|
|
||||||
convert_to_ckpt_input,
|
|
||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
|
||||||
enable_bucket_input,
|
enable_bucket_input,
|
||||||
gradient_checkpointing_input,
|
gradient_checkpointing_input,
|
||||||
full_fp16_input,
|
full_fp16_input,
|
||||||
@ -766,6 +740,7 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
|
save_model_as_dropdown
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
config_file_name,
|
config_file_name,
|
||||||
@ -787,11 +762,8 @@ with interface:
|
|||||||
save_precision_input,
|
save_precision_input,
|
||||||
seed_input,
|
seed_input,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process_input,
|
||||||
convert_to_safetensors_input,
|
|
||||||
convert_to_ckpt_input,
|
|
||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
|
||||||
enable_bucket_input,
|
enable_bucket_input,
|
||||||
gradient_checkpointing_input,
|
gradient_checkpointing_input,
|
||||||
full_fp16_input,
|
full_fp16_input,
|
||||||
@ -799,6 +771,7 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
|
save_model_as_dropdown
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -827,11 +800,8 @@ with interface:
|
|||||||
save_precision_input,
|
save_precision_input,
|
||||||
seed_input,
|
seed_input,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process_input,
|
||||||
convert_to_safetensors_input,
|
|
||||||
convert_to_ckpt_input,
|
|
||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
|
||||||
enable_bucket_input,
|
enable_bucket_input,
|
||||||
gradient_checkpointing_input,
|
gradient_checkpointing_input,
|
||||||
full_fp16_input,
|
full_fp16_input,
|
||||||
@ -839,6 +809,7 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
|
save_model_as_dropdown
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -866,11 +837,8 @@ with interface:
|
|||||||
save_precision_input,
|
save_precision_input,
|
||||||
seed_input,
|
seed_input,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process_input,
|
||||||
convert_to_safetensors_input,
|
|
||||||
convert_to_ckpt_input,
|
|
||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
|
||||||
enable_bucket_input,
|
enable_bucket_input,
|
||||||
gradient_checkpointing_input,
|
gradient_checkpointing_input,
|
||||||
full_fp16_input,
|
full_fp16_input,
|
||||||
@ -878,6 +846,7 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
|
save_model_as_dropdown
|
||||||
],
|
],
|
||||||
outputs=[config_file_name],
|
outputs=[config_file_name],
|
||||||
)
|
)
|
||||||
@ -903,11 +872,8 @@ with interface:
|
|||||||
save_precision_input,
|
save_precision_input,
|
||||||
seed_input,
|
seed_input,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process_input,
|
||||||
convert_to_safetensors_input,
|
|
||||||
convert_to_ckpt_input,
|
|
||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
|
||||||
enable_bucket_input,
|
enable_bucket_input,
|
||||||
gradient_checkpointing_input,
|
gradient_checkpointing_input,
|
||||||
full_fp16_input,
|
full_fp16_input,
|
||||||
@ -915,6 +881,7 @@ with interface:
|
|||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training_input,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam_input,
|
||||||
xformers_input,
|
xformers_input,
|
||||||
|
save_model_as_dropdown
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,6 +35,10 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
|
|||||||
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
||||||
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||||
|
|
||||||
|
# Diffusersの設定を読み込むための参照モデル
|
||||||
|
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
||||||
|
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
||||||
|
|
||||||
|
|
||||||
# region StableDiffusion->Diffusersの変換コード
|
# region StableDiffusion->Diffusersの変換コード
|
||||||
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
||||||
@ -973,7 +977,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
|||||||
keys = list(new_sd.keys())
|
keys = list(new_sd.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("transformer.resblocks.22."):
|
if key.startswith("transformer.resblocks.22."):
|
||||||
new_sd[key.replace(".22.", ".23.")] = new_sd[key]
|
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
||||||
|
|
||||||
# Diffusersに含まれない重みを作っておく
|
# Diffusersに含まれない重みを作っておく
|
||||||
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
||||||
@ -995,6 +999,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
|||||||
del state_dict["state_dict"]
|
del state_dict["state_dict"]
|
||||||
else:
|
else:
|
||||||
# 新しく作る
|
# 新しく作る
|
||||||
|
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
||||||
checkpoint = {}
|
checkpoint = {}
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
strict = False
|
strict = False
|
||||||
@ -1047,14 +1052,24 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
|||||||
|
|
||||||
|
|
||||||
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
||||||
|
if pretrained_model_name_or_path is None:
|
||||||
|
# load default settings for v1/v2
|
||||||
|
if v2:
|
||||||
|
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
||||||
|
else:
|
||||||
|
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
||||||
|
|
||||||
|
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
||||||
if vae is None:
|
if vae is None:
|
||||||
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline(
|
pipeline = StableDiffusionPipeline(
|
||||||
unet=unet,
|
unet=unet,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
|
scheduler=scheduler,
|
||||||
tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
|
tokenizer=tokenizer,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=None,
|
requires_safety_checker=None,
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
# v15: model_util update
|
# v15: model_util update
|
||||||
# v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0
|
# v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0
|
||||||
# v17: add fp16 gradient training (experimental)
|
# v17: add fp16 gradient training (experimental)
|
||||||
|
# v18: add save_model_as option
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
@ -670,8 +671,21 @@ def train(args):
|
|||||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||||
|
|
||||||
# モデル形式のオプション設定を確認する:
|
# モデル形式のオプション設定を確認する:
|
||||||
# v11からDiffUsersから直接落としてくるのもOK(ただし認証がいるやつは未対応)、またv11からDiffUsersも途中保存に対応した
|
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
|
||||||
use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
|
|
||||||
|
if load_stable_diffusion_format:
|
||||||
|
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||||
|
src_diffusers_model_path = None
|
||||||
|
else:
|
||||||
|
src_stable_diffusion_ckpt = None
|
||||||
|
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||||
|
|
||||||
|
if args.save_model_as is None:
|
||||||
|
save_stable_diffusion_format = load_stable_diffusion_format
|
||||||
|
use_safetensors = args.use_safetensors
|
||||||
|
else:
|
||||||
|
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
|
||||||
|
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||||
|
|
||||||
# 乱数系列を初期化する
|
# 乱数系列を初期化する
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
@ -691,7 +705,9 @@ def train(args):
|
|||||||
for cap_path in cap_paths:
|
for cap_path in cap_paths:
|
||||||
if os.path.isfile(cap_path):
|
if os.path.isfile(cap_path):
|
||||||
with open(cap_path, "rt", encoding='utf-8') as f:
|
with open(cap_path, "rt", encoding='utf-8') as f:
|
||||||
caption = f.readlines()[0].strip()
|
lines = f.readlines()
|
||||||
|
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||||
|
caption = lines[0].strip()
|
||||||
break
|
break
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
@ -845,7 +861,7 @@ def train(args):
|
|||||||
save_dtype = torch.float32
|
save_dtype = torch.float32
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if use_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
print("load StableDiffusion checkpoint")
|
print("load StableDiffusion checkpoint")
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
@ -1079,17 +1095,17 @@ def train(args):
|
|||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
|
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
|
||||||
print("saving checkpoint.")
|
print("saving checkpoint.")
|
||||||
if use_stable_diffusion_format:
|
if save_stable_diffusion_format:
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1))
|
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
|
||||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
|
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
|
||||||
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype, vae)
|
src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
|
||||||
else:
|
else:
|
||||||
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
|
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
|
||||||
unwrap_model(unet), args.pretrained_model_name_or_path,
|
unwrap_model(unet), src_diffusers_model_path,
|
||||||
use_safetensors=args.use_safetensors)
|
use_safetensors=use_safetensors)
|
||||||
|
|
||||||
if args.save_state:
|
if args.save_state:
|
||||||
print("saving state.")
|
print("saving state.")
|
||||||
@ -1110,18 +1126,17 @@ def train(args):
|
|||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
if use_stable_diffusion_format:
|
if save_stable_diffusion_format:
|
||||||
ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(args.use_safetensors))
|
ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors))
|
||||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
||||||
args.pretrained_model_name_or_path, epoch, global_step, save_dtype, vae)
|
src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae)
|
||||||
else:
|
else:
|
||||||
# Create the pipeline using using the trained modules and save it.
|
|
||||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||||
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
|
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path,
|
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_diffusers_model_path,
|
||||||
use_safetensors=args.use_safetensors)
|
use_safetensors=use_safetensors)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
@ -1149,8 +1164,12 @@ if __name__ == '__main__':
|
|||||||
help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
|
help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
|
||||||
parser.add_argument("--output_dir", type=str, default=None,
|
parser.add_argument("--output_dir", type=str, default=None,
|
||||||
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||||
|
parser.add_argument("--save_precision", type=str, default=None,
|
||||||
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
|
||||||
|
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
|
||||||
|
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
|
||||||
parser.add_argument("--use_safetensors", action='store_true',
|
parser.add_argument("--use_safetensors", action='store_true',
|
||||||
help="use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する")
|
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
@ -1195,8 +1214,6 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||||
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
||||||
parser.add_argument("--save_precision", type=str, default=None,
|
|
||||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
|
|
||||||
parser.add_argument("--clip_skip", type=int, default=None,
|
parser.add_argument("--clip_skip", type=int, default=None,
|
||||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||||
parser.add_argument("--logging_dir", type=str, default=None,
|
parser.add_argument("--logging_dir", type=str, default=None,
|
||||||
|
Loading…
Reference in New Issue
Block a user