v18: Save model as option added

This commit is contained in:
bmaltais 2022-12-17 20:36:31 -05:00
parent fc22813b8f
commit f459c32a3e
4 changed files with 97 additions and 93 deletions

View File

@ -129,6 +129,11 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
- Lord of the universe - cacoe (twitter: @cac0e) - Lord of the universe - cacoe (twitter: @cac0e)
## Change history ## Change history
* 12/17 (v18) update:
- Save model as option added to train_db_fixed.py
- Save model as option added to GUI
- Retire "Model conversion" parameters that was essentially performing the same function as the new `--save_model_as` parameter
* 12/17 (v17.2) update: * 12/17 (v17.2) update:
- Adding new dataset balancing utility. - Adding new dataset balancing utility.
* 12/17 (v17.1) update: * 12/17 (v17.1) update:

View File

@ -47,11 +47,8 @@ def save_configuration(
save_precision, save_precision,
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent, cache_latent,
caption_extention, caption_extention,
use_safetensors,
enable_bucket, enable_bucket,
gradient_checkpointing, gradient_checkpointing,
full_fp16, full_fp16,
@ -59,6 +56,7 @@ def save_configuration(
stop_text_encoder_training, stop_text_encoder_training,
use_8bit_adam, use_8bit_adam,
xformers, xformers,
save_model_as
): ):
original_file_path = file_path original_file_path = file_path
@ -103,11 +101,8 @@ def save_configuration(
'save_precision': save_precision, 'save_precision': save_precision,
'seed': seed, 'seed': seed,
'num_cpu_threads_per_process': num_cpu_threads_per_process, 'num_cpu_threads_per_process': num_cpu_threads_per_process,
'convert_to_safetensors': convert_to_safetensors,
'convert_to_ckpt': convert_to_ckpt,
'cache_latent': cache_latent, 'cache_latent': cache_latent,
'caption_extention': caption_extention, 'caption_extention': caption_extention,
'use_safetensors': use_safetensors,
'enable_bucket': enable_bucket, 'enable_bucket': enable_bucket,
'gradient_checkpointing': gradient_checkpointing, 'gradient_checkpointing': gradient_checkpointing,
'full_fp16': full_fp16, 'full_fp16': full_fp16,
@ -115,6 +110,7 @@ def save_configuration(
'stop_text_encoder_training': stop_text_encoder_training, 'stop_text_encoder_training': stop_text_encoder_training,
'use_8bit_adam': use_8bit_adam, 'use_8bit_adam': use_8bit_adam,
'xformers': xformers, 'xformers': xformers,
'save_model_as': save_model_as
} }
# Save the data to the selected file # Save the data to the selected file
@ -144,11 +140,8 @@ def open_configuration(
save_precision, save_precision,
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent, cache_latent,
caption_extention, caption_extention,
use_safetensors,
enable_bucket, enable_bucket,
gradient_checkpointing, gradient_checkpointing,
full_fp16, full_fp16,
@ -156,6 +149,7 @@ def open_configuration(
stop_text_encoder_training, stop_text_encoder_training,
use_8bit_adam, use_8bit_adam,
xformers, xformers,
save_model_as
): ):
original_file_path = file_path original_file_path = file_path
@ -195,11 +189,8 @@ def open_configuration(
my_data.get( my_data.get(
'num_cpu_threads_per_process', num_cpu_threads_per_process 'num_cpu_threads_per_process', num_cpu_threads_per_process
), ),
my_data.get('convert_to_safetensors', convert_to_safetensors),
my_data.get('convert_to_ckpt', convert_to_ckpt),
my_data.get('cache_latent', cache_latent), my_data.get('cache_latent', cache_latent),
my_data.get('caption_extention', caption_extention), my_data.get('caption_extention', caption_extention),
my_data.get('use_safetensors', use_safetensors),
my_data.get('enable_bucket', enable_bucket), my_data.get('enable_bucket', enable_bucket),
my_data.get('gradient_checkpointing', gradient_checkpointing), my_data.get('gradient_checkpointing', gradient_checkpointing),
my_data.get('full_fp16', full_fp16), my_data.get('full_fp16', full_fp16),
@ -207,6 +198,7 @@ def open_configuration(
my_data.get('stop_text_encoder_training', stop_text_encoder_training), my_data.get('stop_text_encoder_training', stop_text_encoder_training),
my_data.get('use_8bit_adam', use_8bit_adam), my_data.get('use_8bit_adam', use_8bit_adam),
my_data.get('xformers', xformers), my_data.get('xformers', xformers),
my_data.get('save_model_as', save_model_as)
) )
@ -229,11 +221,8 @@ def train_model(
save_precision, save_precision,
seed, seed,
num_cpu_threads_per_process, num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent, cache_latent,
caption_extention, caption_extention,
use_safetensors,
enable_bucket, enable_bucket,
gradient_checkpointing, gradient_checkpointing,
full_fp16, full_fp16,
@ -241,6 +230,7 @@ def train_model(
stop_text_encoder_training_pct, stop_text_encoder_training_pct,
use_8bit_adam, use_8bit_adam,
xformers, xformers,
save_model_as
): ):
def save_inference_file(output_dir, v2, v_parameterization): def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required # Copy inference model for v2 if required
@ -352,8 +342,6 @@ def train_model(
run_cmd += ' --v_parameterization' run_cmd += ' --v_parameterization'
if cache_latent: if cache_latent:
run_cmd += ' --cache_latents' run_cmd += ' --cache_latents'
if use_safetensors:
run_cmd += ' --use_safetensors'
if enable_bucket: if enable_bucket:
run_cmd += ' --enable_bucket' run_cmd += ' --enable_bucket'
if gradient_checkpointing: if gradient_checkpointing:
@ -388,39 +376,20 @@ def train_model(
run_cmd += f' --logging_dir={logging_dir}' run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --caption_extention={caption_extention}' run_cmd += f' --caption_extention={caption_extention}'
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}' run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
print(run_cmd) print(run_cmd)
# Run the command # Run the command
subprocess.run(run_cmd) subprocess.run(run_cmd)
# check if output_dir/last is a directory... therefore it is a diffuser model # check if output_dir/last is a folder... therefore it is a diffuser model
last_dir = pathlib.Path(f'{output_dir}/last') last_dir = pathlib.Path(f'{output_dir}/last')
print(last_dir)
if last_dir.is_dir():
if convert_to_ckpt:
print(f'Converting diffuser model {last_dir} to {last_dir}.ckpt')
os.system(
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.ckpt --{save_precision}'
)
save_inference_file(output_dir, v2, v_parameterization) if not last_dir.is_dir():
if convert_to_safetensors:
print(
f'Converting diffuser model {last_dir} to {last_dir}.safetensors'
)
os.system(
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}'
)
save_inference_file(output_dir, v2, v_parameterization)
else:
# Copy inference model for v2 if required # Copy inference model for v2 if required
save_inference_file(output_dir, v2, v_parameterization) save_inference_file(output_dir, v2, v_parameterization)
# Return the values of the variables as a dictionary
# return
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to search for # define a list of substrings to search for
@ -533,6 +502,17 @@ with interface:
'CompVis/stable-diffusion-v1-4', 'CompVis/stable-diffusion-v1-4',
], ],
) )
save_model_as_dropdown = gr.Dropdown(
label='Save trained model as',
choices=[
'same as source model',
'ckpt',
'diffusers',
"diffusers_safetensors",
'safetensors',
],
value='same as source model'
)
with gr.Row(): with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True) v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox( v_parameterization_input = gr.Checkbox(
@ -557,7 +537,7 @@ with interface:
with gr.Row(): with gr.Row():
train_data_dir_input = gr.Textbox( train_data_dir_input = gr.Textbox(
label='Image folder', label='Image folder',
placeholder='Directory where the training folders containing the images are located', placeholder='Folder where the training folders containing the images are located',
) )
train_data_dir_input_folder = gr.Button( train_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
@ -567,7 +547,7 @@ with interface:
) )
reg_data_dir_input = gr.Textbox( reg_data_dir_input = gr.Textbox(
label='Regularisation folder', label='Regularisation folder',
placeholder='(Optional) Directory where where the regularization folders containing the images are located', placeholder='(Optional) Folder where where the regularization folders containing the images are located',
) )
reg_data_dir_input_folder = gr.Button( reg_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
@ -577,8 +557,8 @@ with interface:
) )
with gr.Row(): with gr.Row():
output_dir_input = gr.Textbox( output_dir_input = gr.Textbox(
label='Output directory', label='Output folder',
placeholder='Directory to output trained model', placeholder='Folder to output trained model',
) )
output_dir_input_folder = gr.Button( output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
@ -587,8 +567,8 @@ with interface:
get_folder_path, outputs=output_dir_input get_folder_path, outputs=output_dir_input
) )
logging_dir_input = gr.Textbox( logging_dir_input = gr.Textbox(
label='Logging directory', label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this directory', placeholder='Optional: enable logging and output TensorBoard log to this folder',
) )
logging_dir_input_folder = gr.Button( logging_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small' '📂', elem_id='open_folder_small'
@ -694,9 +674,6 @@ with interface:
no_token_padding_input = gr.Checkbox( no_token_padding_input = gr.Checkbox(
label='No token padding', value=False label='No token padding', value=False
) )
use_safetensors_input = gr.Checkbox(
label='Use safetensor when saving', value=False
)
gradient_checkpointing_input = gr.Checkbox( gradient_checkpointing_input = gr.Checkbox(
label='Gradient checkpointing', value=False label='Gradient checkpointing', value=False
@ -711,13 +688,6 @@ with interface:
) )
xformers_input = gr.Checkbox(label='Use xformers', value=True) xformers_input = gr.Checkbox(label='Use xformers', value=True)
with gr.Tab('Model conversion'):
convert_to_safetensors_input = gr.Checkbox(
label='Convert to SafeTensors', value=True
)
convert_to_ckpt_input = gr.Checkbox(
label='Convert to CKPT', value=False
)
with gr.Tab('Utilities'): with gr.Tab('Utilities'):
# Dreambooth folder creation tab # Dreambooth folder creation tab
gradio_dreambooth_folder_creation_tab( gradio_dreambooth_folder_creation_tab(
@ -729,6 +699,13 @@ with interface:
# Captionning tab # Captionning tab
gradio_caption_gui_tab() gradio_caption_gui_tab()
gradio_dataset_balancing_tab() gradio_dataset_balancing_tab()
# with gr.Tab('Model conversion'):
# convert_to_safetensors_input = gr.Checkbox(
# label='Convert to SafeTensors', value=True
# )
# convert_to_ckpt_input = gr.Checkbox(
# label='Convert to CKPT', value=False
# )
button_run = gr.Button('Train model') button_run = gr.Button('Train model')
@ -754,11 +731,8 @@ with interface:
save_precision_input, save_precision_input,
seed_input, seed_input,
num_cpu_threads_per_process_input, num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input, cache_latent_input,
caption_extention_input, caption_extention_input,
use_safetensors_input,
enable_bucket_input, enable_bucket_input,
gradient_checkpointing_input, gradient_checkpointing_input,
full_fp16_input, full_fp16_input,
@ -766,6 +740,7 @@ with interface:
stop_text_encoder_training_input, stop_text_encoder_training_input,
use_8bit_adam_input, use_8bit_adam_input,
xformers_input, xformers_input,
save_model_as_dropdown
], ],
outputs=[ outputs=[
config_file_name, config_file_name,
@ -787,11 +762,8 @@ with interface:
save_precision_input, save_precision_input,
seed_input, seed_input,
num_cpu_threads_per_process_input, num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input, cache_latent_input,
caption_extention_input, caption_extention_input,
use_safetensors_input,
enable_bucket_input, enable_bucket_input,
gradient_checkpointing_input, gradient_checkpointing_input,
full_fp16_input, full_fp16_input,
@ -799,6 +771,7 @@ with interface:
stop_text_encoder_training_input, stop_text_encoder_training_input,
use_8bit_adam_input, use_8bit_adam_input,
xformers_input, xformers_input,
save_model_as_dropdown
], ],
) )
@ -827,11 +800,8 @@ with interface:
save_precision_input, save_precision_input,
seed_input, seed_input,
num_cpu_threads_per_process_input, num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input, cache_latent_input,
caption_extention_input, caption_extention_input,
use_safetensors_input,
enable_bucket_input, enable_bucket_input,
gradient_checkpointing_input, gradient_checkpointing_input,
full_fp16_input, full_fp16_input,
@ -839,6 +809,7 @@ with interface:
stop_text_encoder_training_input, stop_text_encoder_training_input,
use_8bit_adam_input, use_8bit_adam_input,
xformers_input, xformers_input,
save_model_as_dropdown
], ],
outputs=[config_file_name], outputs=[config_file_name],
) )
@ -866,11 +837,8 @@ with interface:
save_precision_input, save_precision_input,
seed_input, seed_input,
num_cpu_threads_per_process_input, num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input, cache_latent_input,
caption_extention_input, caption_extention_input,
use_safetensors_input,
enable_bucket_input, enable_bucket_input,
gradient_checkpointing_input, gradient_checkpointing_input,
full_fp16_input, full_fp16_input,
@ -878,6 +846,7 @@ with interface:
stop_text_encoder_training_input, stop_text_encoder_training_input,
use_8bit_adam_input, use_8bit_adam_input,
xformers_input, xformers_input,
save_model_as_dropdown
], ],
outputs=[config_file_name], outputs=[config_file_name],
) )
@ -903,11 +872,8 @@ with interface:
save_precision_input, save_precision_input,
seed_input, seed_input,
num_cpu_threads_per_process_input, num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input, cache_latent_input,
caption_extention_input, caption_extention_input,
use_safetensors_input,
enable_bucket_input, enable_bucket_input,
gradient_checkpointing_input, gradient_checkpointing_input,
full_fp16_input, full_fp16_input,
@ -915,6 +881,7 @@ with interface:
stop_text_encoder_training_input, stop_text_encoder_training_input,
use_8bit_adam_input, use_8bit_adam_input,
xformers_input, xformers_input,
save_model_as_dropdown
], ],
) )

View File

@ -35,6 +35,10 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024 V2_UNET_PARAMS_CONTEXT_DIM = 1024
# Diffusersの設定を読み込むための参照モデル
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
# region StableDiffusion->Diffusersの変換コード # region StableDiffusion->Diffusersの変換コード
# convert_original_stable_diffusion_to_diffusers をコピーして修正しているASL 2.0 # convert_original_stable_diffusion_to_diffusers をコピーして修正しているASL 2.0
@ -973,7 +977,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
keys = list(new_sd.keys()) keys = list(new_sd.keys())
for key in keys: for key in keys:
if key.startswith("transformer.resblocks.22."): if key.startswith("transformer.resblocks.22."):
new_sd[key.replace(".22.", ".23.")] = new_sd[key] new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
# Diffusersに含まれない重みを作っておく # Diffusersに含まれない重みを作っておく
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
@ -995,6 +999,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
del state_dict["state_dict"] del state_dict["state_dict"]
else: else:
# 新しく作る # 新しく作る
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
checkpoint = {} checkpoint = {}
state_dict = {} state_dict = {}
strict = False strict = False
@ -1047,14 +1052,24 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
if pretrained_model_name_or_path is None:
# load default settings for v1/v2
if v2:
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
else:
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
if vae is None: if vae is None:
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
pipeline = StableDiffusionPipeline( pipeline = StableDiffusionPipeline(
unet=unet, unet=unet,
text_encoder=text_encoder, text_encoder=text_encoder,
vae=vae, vae=vae,
scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"), scheduler=scheduler,
tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"), tokenizer=tokenizer,
safety_checker=None, safety_checker=None,
feature_extractor=None, feature_extractor=None,
requires_safety_checker=None, requires_safety_checker=None,

View File

@ -16,6 +16,7 @@
# v15: model_util update # v15: model_util update
# v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0 # v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0
# v17: add fp16 gradient training (experimental) # v17: add fp16 gradient training (experimental)
# v18: add save_model_as option
import gc import gc
import time import time
@ -670,8 +671,21 @@ def train(args):
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
# モデル形式のオプション設定を確認する: # モデル形式のオプション設定を確認する:
# v11からDiffUsersから直接落としてくるのもOKただし認証がいるやつは未対応、またv11からDiffUsersも途中保存に対応した load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None
else:
src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors
else:
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# 乱数系列を初期化する # 乱数系列を初期化する
if args.seed is not None: if args.seed is not None:
@ -691,7 +705,9 @@ def train(args):
for cap_path in cap_paths: for cap_path in cap_paths:
if os.path.isfile(cap_path): if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding='utf-8') as f: with open(cap_path, "rt", encoding='utf-8') as f:
caption = f.readlines()[0].strip() lines = f.readlines()
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
break break
return caption return caption
@ -845,7 +861,7 @@ def train(args):
save_dtype = torch.float32 save_dtype = torch.float32
# モデルを読み込む # モデルを読み込む
if use_stable_diffusion_format: if load_stable_diffusion_format:
print("load StableDiffusion checkpoint") print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
else: else:
@ -1079,17 +1095,17 @@ def train(args):
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
print("saving checkpoint.") print("saving checkpoint.")
if use_stable_diffusion_format: if save_stable_diffusion_format:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1)) ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1))
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype, vae) src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae)
else: else:
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
unwrap_model(unet), args.pretrained_model_name_or_path, unwrap_model(unet), src_diffusers_model_path,
use_safetensors=args.use_safetensors) use_safetensors=use_safetensors)
if args.save_state: if args.save_state:
print("saving state.") print("saving state.")
@ -1110,18 +1126,17 @@ def train(args):
if is_main_process: if is_main_process:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
if use_stable_diffusion_format: if save_stable_diffusion_format:
ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(args.use_safetensors)) ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors))
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
args.pretrained_model_name_or_path, epoch, global_step, save_dtype, vae) src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae)
else: else:
# Create the pipeline using using the trained modules and save it.
print(f"save trained model as Diffusers to {args.output_dir}") print(f"save trained model as Diffusers to {args.output_dir}")
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path, model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_diffusers_model_path,
use_safetensors=args.use_safetensors) use_safetensors=use_safetensors)
print("model saved.") print("model saved.")
@ -1149,8 +1164,12 @@ if __name__ == '__main__':
help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数") help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
parser.add_argument("--output_dir", type=str, default=None, parser.add_argument("--output_dir", type=str, default=None,
help="directory to output trained model / 学習後のモデル出力先ディレクトリ") help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存するStableDiffusion形式での保存時のみ有効")
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
parser.add_argument("--use_safetensors", action='store_true', parser.add_argument("--use_safetensors", action='store_true',
help="use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する") help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時")
parser.add_argument("--save_every_n_epochs", type=int, default=None, parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_state", action="store_true", parser.add_argument("--save_state", action="store_true",
@ -1195,8 +1214,6 @@ if __name__ == '__main__':
parser.add_argument("--mixed_precision", type=str, default="no", parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存するStableDiffusion形式での保存時のみ有効")
parser.add_argument("--clip_skip", type=int, default=None, parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上") help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
parser.add_argument("--logging_dir", type=str, default=None, parser.add_argument("--logging_dir", type=str, default=None,