Add support for new arguments:

- max_train_epochs
- max_data_loader_n_workers
Move some of the codeto  common gui library.
This commit is contained in:
bmaltais 2023-01-15 11:05:22 -05:00
parent 43116feda8
commit 6aed2bb402
18 changed files with 781 additions and 1293 deletions

201
LICENSE.md Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2022] [kohya-ss]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -99,6 +99,13 @@ Once you have created the LoRA network you can generate images via auto1111 by i
## Change history ## Change history
* 2023/01/15 (v20.2.1):
- Merging latest code update from kohya
- Added `--max_train_epochs` and `--max_data_loader_n_workers` option for each training script.
- If you specify the number of training epochs with `--max_train_epochs`, the number of steps is calculated from the number of epochs automatically.
- You can set the number of workers for DataLoader with `--max_data_loader_n_workers`, default is 8. The lower number may reduce the main memory usage and the time between epochs, but may cause slower dataloading (training).
- Fix loading some VAE or .safetensors as VAE is failed for `--vae` option. Thanks to Fannovel16!
- Add negative prompt scaling for `gen_img_diffusers.py` You can set another conditioning scale to the negative prompt with `--negative_scale` option, and `--nl` option for the prompt. Thanks to laksjdjf!
* 2023/01/11 (v20.2.0): * 2023/01/11 (v20.2.0):
- Add support for max token lenght - Add support for max token lenght
* 2023/01/10 (v20.1.1): * 2023/01/10 (v20.1.1):

View File

@ -20,6 +20,8 @@ from library.common_gui import (
color_aug_changed, color_aug_changed,
save_inference_file, save_inference_file,
set_pretrained_model_name_or_path_input, set_pretrained_model_name_or_path_input,
gradio_advanced_training,
run_cmd_advanced_training,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -74,6 +76,8 @@ def save_configuration(
vae, vae,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -153,6 +157,8 @@ def open_configuration(
vae, vae,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -216,6 +222,8 @@ def train_model(
vae, vae,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -372,6 +380,11 @@ def train_model(
run_cmd += f' --output_name="{output_name}"' run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75): if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}' run_cmd += f' --max_token_length={max_token_length}'
if not max_train_epochs == '':
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
if not max_data_loader_n_workers == '':
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
print(run_cmd) print(run_cmd)
# Run the command # Run the command
@ -708,6 +721,16 @@ def dreambooth_tab(
], ],
value='75', value='75',
) )
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
# with gr.Row():
# max_train_epochs = gr.Textbox(
# label='Max train epoch',
# placeholder='(Optional) Override number of epoch',
# )
# max_data_loader_n_workers = gr.Textbox(
# label='Max num workers for DataLoader',
# placeholder='(Optional) Override number of epoch. Default: 8',
# )
with gr.Tab('Tools'): with gr.Tab('Tools'):
gr.Markdown( gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...' 'This section provide Dreambooth tools to help setup your dataset...'
@ -760,6 +783,8 @@ def dreambooth_tab(
vae, vae,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
] ]
button_open_config.click( button_open_config.click(

View File

@ -16,456 +16,326 @@ import library.train_util as train_util
def collate_fn(examples): def collate_fn(examples):
return examples[0] return examples[0]
def train(args): def train(args):
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True) train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents cache_latents = args.cache_latents
if args.seed is not None: if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args) tokenizer = train_util.load_tokenizer(args)
train_dataset = train_util.FineTuningDataset( train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
args.in_json, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.train_batch_size, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.train_data_dir, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
tokenizer, args.dataset_repeats, args.debug_dataset)
args.max_token_length, train_dataset.make_buckets()
args.shuffle_caption,
args.keep_tokens,
args.resolution,
args.enable_bucket,
args.min_bucket_reso,
args.max_bucket_reso,
args.flip_aug,
args.color_aug,
args.face_crop_aug_range,
args.random_crop,
args.dataset_repeats,
args.debug_dataset,
)
train_dataset.make_buckets()
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset) train_util.debug_dataset(train_dataset)
return return
if len(train_dataset) == 0: if len(train_dataset) == 0:
print( print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
'No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。' return
)
return
# acceleratorを準備する # acceleratorを準備する
print('prepare accelerator') print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む # モデルを読み込む
( text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
text_encoder,
vae,
unet,
load_stable_diffusion_format,
) = train_util.load_target_model(args, weight_dtype)
# verify load/save model formats # verify load/save model formats
if load_stable_diffusion_format: if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None src_diffusers_model_path = None
else: else:
src_stable_diffusion_ckpt = None src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None: if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors use_safetensors = args.use_safetensors
else: else:
save_stable_diffusion_format = ( save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
args.save_model_as.lower() == 'ckpt' use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
or args.save_model_as.lower() == 'safetensors'
)
use_safetensors = args.use_safetensors or (
'safetensors' in args.save_model_as.lower()
)
# Diffusers版のxformers使用フラグを設定する関数 # Diffusers版のxformers使用フラグを設定する関数
def set_diffusers_xformers_flag(model, valid): def set_diffusers_xformers_flag(model, valid):
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
# Recursively walk through all the children. # Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method # Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message # gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module): def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, 'set_use_memory_efficient_attention_xformers'): if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid) module.set_use_memory_efficient_attention_xformers(valid)
for child in module.children(): for child in module.children():
fn_recursive_set_mem_eff(child) fn_recursive_set_mem_eff(child)
fn_recursive_set_mem_eff(model) fn_recursive_set_mem_eff(model)
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
if args.diffusers_xformers: if args.diffusers_xformers:
print('Use xformers by Diffusers') print("Use xformers by Diffusers")
set_diffusers_xformers_flag(unet, True) set_diffusers_xformers_flag(unet, True)
else: else:
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
print("Disable Diffusers' xformers") print("Disable Diffusers' xformers")
set_diffusers_xformers_flag(unet, False) set_diffusers_xformers_flag(unet, False)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset.cache_latents(vae) train_dataset.cache_latents(vae)
vae.to('cpu') vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
# 学習を準備する:モデルを適切な状態にする # 学習を準備する:モデルを適切な状態にする
training_models = [] training_models = []
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
if args.train_text_encoder:
print("enable text encoder training")
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable()
training_models.append(unet) training_models.append(text_encoder)
else:
if args.train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype)
print('enable text encoder training') text_encoder.requires_grad_(False) # text encoderは学習しない
if args.gradient_checkpointing: if args.gradient_checkpointing:
text_encoder.gradient_checkpointing_enable() text_encoder.gradient_checkpointing_enable()
training_models.append(text_encoder) text_encoder.train() # required for gradient_checkpointing
else: else:
text_encoder.to(accelerator.device, dtype=weight_dtype) text_encoder.eval()
text_encoder.requires_grad_(False) # text encoderは学習しない
if args.gradient_checkpointing: if not cache_latents:
text_encoder.gradient_checkpointing_enable() vae.requires_grad_(False)
text_encoder.train() # required for gradient_checkpointing vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
for m in training_models:
m.requires_grad_(True)
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 8-bit Adamを使う
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num examples / サンプル数: {train_dataset.num_train_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("finetuning")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else: else:
text_encoder.eval() target = noise
if not cache_latents: loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
for m in training_models: accelerator.backward(loss)
m.requires_grad_(True) if accelerator.sync_gradients:
params = [] params_to_clip = []
for m in training_models: for m in training_models:
params.extend(m.parameters()) params_to_clip.extend(m.parameters())
params_to_optimize = params accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
# 学習に必要なクラスを準備する optimizer.step()
print('prepare optimizer, data loader etc.') lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# 8-bit Adamを使う # Checks if the accelerator has performed an optimization step behind the scenes
if args.use_8bit_adam: if accelerator.sync_gradients:
try: progress_bar.update(1)
import bitsandbytes as bnb global_step += 1
except ImportError:
raise ImportError(
'No bitsand bytes / bitsandbytesがインストールされていないようです'
)
print('use 8-bit Adam optimizer')
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate) if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
# dataloaderを準備する loss_total += current_loss
# DataLoaderのプロセス数0はメインプロセスになる avr_loss = loss_total / (step+1)
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
train_dataloader = torch.utils.data.DataLoader( progress_bar.set_postfix(**logs)
train_dataset,
batch_size=1,
shuffle=False,
collate_fn=collate_fn,
num_workers=n_workers,
)
# lr schedulerを用意する if global_step >= args.max_train_steps:
lr_scheduler = diffusers.optimization.get_scheduler( break
args.lr_scheduler,
optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps
* args.gradient_accumulation_steps,
)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする if args.logging_dir is not None:
if args.full_fp16: logs = {"epoch_loss": loss_total / len(train_dataloader)}
assert ( accelerator.log(logs, step=epoch+1)
args.mixed_precision == 'fp16'
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print('enable full fp16 training.')
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい accelerator.wait_for_everyone()
if args.train_text_encoder:
(
unet,
text_encoder,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.save_every_n_epochs is not None:
if args.full_fp16: src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.patch_accelerator_for_fp16_training(accelerator) train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
# resumeする is_main_process = accelerator.is_main_process
if args.resume is not None: if is_main_process:
print(f'resume training from state: {args.resume}') unet = unwrap_model(unet)
accelerator.load_state(args.resume) text_encoder = unwrap_model(text_encoder)
# epoch数を計算する accelerator.end_training()
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
num_train_epochs = math.ceil(
args.max_train_steps / num_update_steps_per_epoch
)
# 学習する if args.save_state:
total_batch_size = ( train_util.save_state_on_train_end(args, accelerator)
args.train_batch_size
* accelerator.num_processes
* args.gradient_accumulation_steps
)
print('running training / 学習開始')
print(f' num examples / サンプル数: {train_dataset.num_train_images}')
print(f' num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}')
print(f' num epochs / epoch数: {num_train_epochs}')
print(f' batch size per device / バッチサイズ: {args.train_batch_size}')
print(
f' total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}'
)
print(
f' gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}'
)
print(f' total optimization steps / 学習ステップ数: {args.max_train_steps}')
progress_bar = tqdm( del accelerator # この後メモリを使うのでこれは消す
range(args.max_train_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc='steps',
)
global_step = 0
noise_scheduler = DDPMScheduler( if is_main_process:
beta_start=0.00085, src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
beta_end=0.012, train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
beta_schedule='scaled_linear', save_dtype, epoch, global_step, text_encoder, unet, vae)
num_train_timesteps=1000, print("model saved.")
clip_sample=False,
)
if accelerator.is_main_process:
accelerator.init_trackers('finetuning')
for epoch in range(num_train_epochs):
print(f'epoch {epoch+1}/{num_train_epochs}')
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(
training_models[0]
): # 複数モデルに対応していない模様だがとりあえずこうしておく
with torch.no_grad():
if 'latents' in batch and batch['latents'] is not None:
latents = batch['latents'].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(
batch['images'].to(dtype=weight_dtype)
).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch['input_ids'].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args,
input_ids,
tokenizer,
text_encoder,
None if not args.full_fp16 else weight_dtype,
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(
latents, noise, timesteps
)
# Predict the noise residual
noise_pred = unet(
noisy_latents, timesteps, encoder_hidden_states
).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(
latents, noise, timesteps
)
else:
target = noise
loss = torch.nn.functional.mse_loss(
noise_pred.float(), target.float(), reduction='mean'
)
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(
params_to_clip, 1.0
) # args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {
'loss': current_loss,
'lr': lr_scheduler.get_last_lr()[0],
}
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {'loss': avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {'epoch_loss': loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
src_path = (
src_stable_diffusion_ckpt
if save_stable_diffusion_format
else src_diffusers_model_path
)
train_util.save_sd_model_on_epoch_end(
args,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
src_path = (
src_stable_diffusion_ckpt
if save_stable_diffusion_format
else src_diffusers_model_path
)
train_util.save_sd_model_on_train_end(
args,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
global_step,
text_encoder,
unet,
vae,
)
print('model saved.')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True) train_util.add_dataset_arguments(parser, False, True)
train_util.add_training_arguments(parser, False) train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)
parser.add_argument( parser.add_argument("--diffusers_xformers", action='store_true',
'--diffusers_xformers', help='use xformers by diffusers / Diffusersでxformersを使用する')
action='store_true', parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
help='use xformers by diffusers / Diffusersでxformersを使用する',
)
parser.add_argument(
'--train_text_encoder',
action='store_true',
help='train text encoder / text encoderも学習する',
)
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)

View File

@ -1,789 +0,0 @@
import gradio as gr
import json
import math
import os
import subprocess
import pathlib
import shutil
import argparse
from library.common_gui import (
get_folder_path,
get_file_path,
get_any_file_path,
get_saveasfile_path,
)
from library.utilities import utilities_tab
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def save_configuration(
save_as,
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization,
train_dir,
image_folder,
output_dir,
logging_dir,
max_resolution,
min_bucket_reso,
max_bucket_reso,
batch_size,
flip_aug,
caption_metadata_filename,
latent_metadata_filename,
full_path,
learning_rate,
lr_scheduler,
lr_warmup,
dataset_repeats,
train_batch_size,
epoch,
save_every_n_epochs,
mixed_precision,
save_precision,
seed,
num_cpu_threads_per_process,
train_text_encoder,
create_caption,
create_buckets,
save_model_as,
caption_extension,
use_8bit_adam,
xformers,
clip_skip,
):
original_file_path = file_path
save_as_bool = True if save_as.get('label') == 'True' else False
if save_as_bool:
print('Save as...')
file_path = get_saveasfile_path(file_path)
else:
print('Save...')
if file_path == None or file_path == '':
file_path = get_saveasfile_path(file_path)
# print(file_path)
if file_path == None:
return original_file_path
# Return the values of the variables as a dictionary
variables = {
'pretrained_model_name_or_path': pretrained_model_name_or_path,
'v2': v2,
'v_parameterization': v_parameterization,
'train_dir': train_dir,
'image_folder': image_folder,
'output_dir': output_dir,
'logging_dir': logging_dir,
'max_resolution': max_resolution,
'min_bucket_reso': min_bucket_reso,
'max_bucket_reso': max_bucket_reso,
'batch_size': batch_size,
'flip_aug': flip_aug,
'caption_metadata_filename': caption_metadata_filename,
'latent_metadata_filename': latent_metadata_filename,
'full_path': full_path,
'learning_rate': learning_rate,
'lr_scheduler': lr_scheduler,
'lr_warmup': lr_warmup,
'dataset_repeats': dataset_repeats,
'train_batch_size': train_batch_size,
'epoch': epoch,
'save_every_n_epochs': save_every_n_epochs,
'mixed_precision': mixed_precision,
'save_precision': save_precision,
'seed': seed,
'num_cpu_threads_per_process': num_cpu_threads_per_process,
'train_text_encoder': train_text_encoder,
'create_buckets': create_buckets,
'create_caption': create_caption,
'save_model_as': save_model_as,
'caption_extension': caption_extension,
'use_8bit_adam': use_8bit_adam,
'xformers': xformers,
'clip_skip': clip_skip,
}
# Save the data to the selected file
with open(file_path, 'w') as file:
json.dump(variables, file)
return file_path
def open_config_file(
file_path,
pretrained_model_name_or_path,
v2,
v_parameterization,
train_dir,
image_folder,
output_dir,
logging_dir,
max_resolution,
min_bucket_reso,
max_bucket_reso,
batch_size,
flip_aug,
caption_metadata_filename,
latent_metadata_filename,
full_path,
learning_rate,
lr_scheduler,
lr_warmup,
dataset_repeats,
train_batch_size,
epoch,
save_every_n_epochs,
mixed_precision,
save_precision,
seed,
num_cpu_threads_per_process,
train_text_encoder,
create_caption,
create_buckets,
save_model_as,
caption_extension,
use_8bit_adam,
xformers,
clip_skip,
):
original_file_path = file_path
file_path = get_file_path(file_path)
if file_path != '' and file_path != None:
print(file_path)
# load variables from JSON file
with open(file_path, 'r') as f:
my_data = json.load(f)
else:
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
my_data = {}
# Return the values of the variables as a dictionary
return (
file_path,
my_data.get(
'pretrained_model_name_or_path', pretrained_model_name_or_path
),
my_data.get('v2', v2),
my_data.get('v_parameterization', v_parameterization),
my_data.get('train_dir', train_dir),
my_data.get('image_folder', image_folder),
my_data.get('output_dir', output_dir),
my_data.get('logging_dir', logging_dir),
my_data.get('max_resolution', max_resolution),
my_data.get('min_bucket_reso', min_bucket_reso),
my_data.get('max_bucket_reso', max_bucket_reso),
my_data.get('batch_size', batch_size),
my_data.get('flip_aug', flip_aug),
my_data.get('caption_metadata_filename', caption_metadata_filename),
my_data.get('latent_metadata_filename', latent_metadata_filename),
my_data.get('full_path', full_path),
my_data.get('learning_rate', learning_rate),
my_data.get('lr_scheduler', lr_scheduler),
my_data.get('lr_warmup', lr_warmup),
my_data.get('dataset_repeats', dataset_repeats),
my_data.get('train_batch_size', train_batch_size),
my_data.get('epoch', epoch),
my_data.get('save_every_n_epochs', save_every_n_epochs),
my_data.get('mixed_precision', mixed_precision),
my_data.get('save_precision', save_precision),
my_data.get('seed', seed),
my_data.get(
'num_cpu_threads_per_process', num_cpu_threads_per_process
),
my_data.get('train_text_encoder', train_text_encoder),
my_data.get('create_buckets', create_buckets),
my_data.get('create_caption', create_caption),
my_data.get('save_model_as', save_model_as),
my_data.get('caption_extension', caption_extension),
my_data.get('use_8bit_adam', use_8bit_adam),
my_data.get('xformers', xformers),
my_data.get('clip_skip', clip_skip),
)
def train_model(
pretrained_model_name_or_path,
v2,
v_parameterization,
train_dir,
image_folder,
output_dir,
logging_dir,
max_resolution,
min_bucket_reso,
max_bucket_reso,
batch_size,
flip_aug,
caption_metadata_filename,
latent_metadata_filename,
full_path,
learning_rate,
lr_scheduler,
lr_warmup,
dataset_repeats,
train_batch_size,
epoch,
save_every_n_epochs,
mixed_precision,
save_precision,
seed,
num_cpu_threads_per_process,
train_text_encoder,
generate_caption_database,
generate_image_buckets,
save_model_as,
caption_extension,
use_8bit_adam,
xformers,
clip_skip,
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
if v2 and v_parameterization:
print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml')
shutil.copy(
f'./v2_inference/v2-inference-v.yaml',
f'{output_dir}/last.yaml',
)
elif v2:
print(f'Saving v2-inference.yaml as {output_dir}/last.yaml')
shutil.copy(
f'./v2_inference/v2-inference.yaml',
f'{output_dir}/last.yaml',
)
# create caption json file
if generate_caption_database:
if not os.path.exists(train_dir):
os.mkdir(train_dir)
for root, dirs, files in os.walk(image_folder):
for dir in dirs:
print(os.path.join(root, dir))
run_cmd = (
f'./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py'
)
if caption_extension == '':
run_cmd += f' --caption_extension=".txt"'
else:
run_cmd += f' --caption_extension={caption_extension}'
run_cmd += f' "{os.path.join(root, dir)}"'
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
if full_path:
run_cmd += f' --full_path'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
# create images buckets
if generate_image_buckets:
run_cmd = (
f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
)
run_cmd += f' "crap"'
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
run_cmd += f' "{pretrained_model_name_or_path}"'
run_cmd += f' --batch_size={batch_size}'
run_cmd += f' --max_resolution={max_resolution}'
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
run_cmd += f' --mixed_precision={mixed_precision}'
if flip_aug:
run_cmd += f' --flip_aug'
if full_path:
run_cmd += f' --full_path'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
image_num = 0
for root, dirs, files in os.walk(image_folder):
for dir in dirs:
image_num += len(
[f for f in os.listdir(os.path.join(root, dir)) if f.endswith('.npz')]
)
print(f'image_num = {image_num}')
repeats = int(image_num) * int(dataset_repeats)
print(f'repeats = {str(repeats)}')
# calculate max_train_steps
max_train_steps = int(
math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
)
# Divide by two because flip augmentation create two copied of the source images
if flip_aug:
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
print(f'max_train_steps = {max_train_steps}')
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
print(f'lr_warmup_steps = {lr_warmup_steps}')
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
if v2:
run_cmd += ' --v2'
if v_parameterization:
run_cmd += ' --v_parameterization'
if train_text_encoder:
run_cmd += ' --train_text_encoder'
if use_8bit_adam:
run_cmd += f' --use_8bit_adam'
if xformers:
run_cmd += f' --xformers'
run_cmd += (
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
)
run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
run_cmd += f' --train_data_dir="{image_folder}"'
run_cmd += f' --output_dir="{output_dir}"'
if not logging_dir == '':
run_cmd += f' --logging_dir="{logging_dir}"'
run_cmd += f' --train_batch_size={train_batch_size}'
run_cmd += f' --dataset_repeats={dataset_repeats}'
run_cmd += f' --learning_rate={learning_rate}'
run_cmd += f' --lr_scheduler={lr_scheduler}'
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
run_cmd += f' --max_train_steps={max_train_steps}'
run_cmd += f' --mixed_precision={mixed_precision}'
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
run_cmd += f' --seed={seed}'
run_cmd += f' --save_precision={save_precision}'
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
if int(clip_skip) > 1:
run_cmd += f' --clip_skip={str(clip_skip)}'
print(run_cmd)
# Run the command
subprocess.run(run_cmd)
# check if output_dir/last is a folder... therefore it is a diffuser model
last_dir = pathlib.Path(f'{output_dir}/last')
if not last_dir.is_dir():
# Copy inference model for v2 if required
save_inference_file(output_dir, v2, v_parameterization)
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to search for
substrings_v2 = [
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
if str(value) in substrings_v2:
print('SD v2 model detected. Setting --v2 parameter')
v2 = True
v_parameterization = False
return value, v2, v_parameterization
# define a list of substrings to search for v-objective
substrings_v_parameterization = [
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
]
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
if str(value) in substrings_v_parameterization:
print(
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
)
v2 = True
v_parameterization = True
return value, v2, v_parameterization
# define a list of substrings to v1.x
substrings_v1_model = [
'CompVis/stable-diffusion-v1-4',
'runwayml/stable-diffusion-v1-5',
]
if str(value) in substrings_v1_model:
v2 = False
v_parameterization = False
return value, v2, v_parameterization
if value == 'custom':
value = ''
v2 = False
v_parameterization = False
return value, v2, v_parameterization
def remove_doublequote(file_path):
if file_path != None:
file_path = file_path.replace('"', '')
return file_path
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False)
gr.Markdown('Train a custom model using kohya finetune python code...')
with gr.Accordion('Configuration file', open=False):
with gr.Row():
button_open_config = gr.Button(
f'Open {folder_symbol}', elem_id='open_folder'
)
button_save_config = gr.Button(
f'Save {save_style_symbol}', elem_id='open_folder'
)
button_save_as_config = gr.Button(
f'Save as... {save_style_symbol}',
elem_id='open_folder',
)
config_file_name = gr.Textbox(
label='', placeholder='type file path or use buttons...'
)
config_file_name.change(
remove_doublequote,
inputs=[config_file_name],
outputs=[config_file_name],
)
with gr.Tab('Source model'):
# Define the input elements
with gr.Row():
pretrained_model_name_or_path_input = gr.Textbox(
label='Pretrained model name or path',
placeholder='enter the path to custom model or name of pretrained model',
)
pretrained_model_name_or_path_file = gr.Button(
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_file.click(
get_any_file_path,
inputs=pretrained_model_name_or_path_input,
outputs=pretrained_model_name_or_path_input,
)
pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_folder.click(
get_folder_path,
inputs=pretrained_model_name_or_path_input,
outputs=pretrained_model_name_or_path_input,
)
model_list = gr.Dropdown(
label='(Optional) Model Quick Pick',
choices=[
'custom',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
],
)
save_model_as_dropdown = gr.Dropdown(
label='Save trained model as',
choices=[
'same as source model',
'ckpt',
'diffusers',
'diffusers_safetensors',
'safetensors',
],
value='same as source model',
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox(
label='v_parameterization', value=False
)
model_list.change(
set_pretrained_model_name_or_path_input,
inputs=[model_list, v2_input, v_parameterization_input],
outputs=[
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
],
)
with gr.Tab('Folders'):
with gr.Row():
train_dir_input = gr.Textbox(
label='Training config folder',
placeholder='folder where the training configuration files will be saved',
)
train_dir_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
train_dir_folder.click(get_folder_path, outputs=train_dir_input)
image_folder_input = gr.Textbox(
label='Training Image folder',
placeholder='folder where the training images are located',
)
image_folder_input_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
image_folder_input_folder.click(
get_folder_path, outputs=image_folder_input
)
with gr.Row():
output_dir_input = gr.Textbox(
label='Output folder',
placeholder='folder where the model will be saved',
)
output_dir_input_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
output_dir_input_folder.click(
get_folder_path, outputs=output_dir_input
)
logging_dir_input = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
logging_dir_input_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
logging_dir_input_folder.click(
get_folder_path, outputs=logging_dir_input
)
train_dir_input.change(
remove_doublequote,
inputs=[train_dir_input],
outputs=[train_dir_input],
)
image_folder_input.change(
remove_doublequote,
inputs=[image_folder_input],
outputs=[image_folder_input],
)
output_dir_input.change(
remove_doublequote,
inputs=[output_dir_input],
outputs=[output_dir_input],
)
with gr.Tab('Dataset preparation'):
with gr.Row():
max_resolution_input = gr.Textbox(
label='Resolution (width,height)', value='512,512'
)
min_bucket_reso = gr.Textbox(
label='Min bucket resolution', value='256'
)
max_bucket_reso = gr.Textbox(
label='Max bucket resolution', value='1024'
)
batch_size = gr.Textbox(label='Batch size', value='1')
with gr.Accordion('Advanced parameters', open=False):
with gr.Row():
caption_metadata_filename = gr.Textbox(
label='Caption metadata filename', value='meta_cap.json'
)
latent_metadata_filename = gr.Textbox(
label='Latent metadata filename', value='meta_lat.json'
)
full_path = gr.Checkbox(label='Use full path', value=True)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
with gr.Tab('Training parameters'):
with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
lr_scheduler_input = gr.Dropdown(
label='LR Scheduler',
choices=[
'constant',
'constant_with_warmup',
'cosine',
'cosine_with_restarts',
'linear',
'polynomial',
],
value='constant',
)
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
with gr.Row():
dataset_repeats_input = gr.Textbox(
label='Dataset repeats', value=40
)
train_batch_size_input = gr.Slider(
minimum=1,
maximum=32,
label='Train batch size',
value=1,
step=1,
)
epoch_input = gr.Textbox(label='Epoch', value=1)
save_every_n_epochs_input = gr.Textbox(
label='Save every N epochs', value=1
)
with gr.Row():
mixed_precision_input = gr.Dropdown(
label='Mixed precision',
choices=[
'no',
'fp16',
'bf16',
],
value='fp16',
)
save_precision_input = gr.Dropdown(
label='Save precision',
choices=[
'float',
'fp16',
'bf16',
],
value='fp16',
)
num_cpu_threads_per_process_input = gr.Slider(
minimum=1,
maximum=os.cpu_count(),
step=1,
label='Number of CPU threads per process',
value=os.cpu_count(),
)
seed_input = gr.Textbox(label='Seed', value=1234)
with gr.Row():
caption_extention_input = gr.Textbox(
label='Caption Extension',
placeholder='(Optional) Extension for caption files. default: .txt',
)
train_text_encoder_input = gr.Checkbox(
label='Train text encoder', value=True
)
with gr.Accordion('Advanced parameters', open=False):
with gr.Row():
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
xformers = gr.Checkbox(label='Use xformers', value=True)
clip_skip = gr.Slider(
label='Clip skip', value='1', minimum=1, maximum=12, step=1
)
with gr.Box():
with gr.Row():
create_caption = gr.Checkbox(
label='Generate caption metadata', value=True
)
create_buckets = gr.Checkbox(
label='Generate image buckets metadata', value=True
)
button_run = gr.Button('Train model')
settings_list = [
pretrained_model_name_or_path_input,
v2_input,
v_parameterization_input,
train_dir_input,
image_folder_input,
output_dir_input,
logging_dir_input,
max_resolution_input,
min_bucket_reso,
max_bucket_reso,
batch_size,
flip_aug,
caption_metadata_filename,
latent_metadata_filename,
full_path,
learning_rate_input,
lr_scheduler_input,
lr_warmup_input,
dataset_repeats_input,
train_batch_size_input,
epoch_input,
save_every_n_epochs_input,
mixed_precision_input,
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
train_text_encoder_input,
create_caption,
create_buckets,
save_model_as_dropdown,
caption_extention_input,
use_8bit_adam,
xformers,
clip_skip,
]
button_run.click(train_model, inputs=settings_list)
button_open_config.click(
open_config_file,
inputs=[config_file_name] + settings_list,
outputs=[config_file_name] + settings_list,
)
button_save_config.click(
save_configuration,
inputs=[dummy_ft_false, config_file_name] + settings_list,
outputs=[config_file_name],
)
button_save_as_config.click(
save_configuration,
inputs=[dummy_ft_true, config_file_name] + settings_list,
outputs=[config_file_name],
)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
parser.add_argument(
'--username', type=str, default='', help='Username for authentication'
)
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
args = parser.parse_args()
UI(username=args.username, password=args.password)

View File

@ -13,6 +13,7 @@ from library.common_gui import (
get_saveasfile_path, get_saveasfile_path,
save_inference_file, save_inference_file,
set_pretrained_model_name_or_path_input, set_pretrained_model_name_or_path_input,
gradio_advanced_training,run_cmd_advanced_training
) )
from library.utilities import utilities_tab from library.utilities import utilities_tab
@ -67,6 +68,8 @@ def save_configuration(
shuffle_caption, shuffle_caption,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -150,6 +153,8 @@ def open_config_file(
shuffle_caption, shuffle_caption,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -218,6 +223,8 @@ def train_model(
shuffle_caption, shuffle_caption,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
# create caption json file # create caption json file
if generate_caption_database: if generate_caption_database:
@ -336,6 +343,7 @@ def train_model(
run_cmd += f' --output_name="{output_name}"' run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75): if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}' run_cmd += f' --max_token_length={max_token_length}'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
print(run_cmd) print(run_cmd)
# Run the command # Run the command
@ -656,6 +664,7 @@ def finetune_tab():
], ],
value='75', value='75',
) )
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
with gr.Box(): with gr.Box():
with gr.Row(): with gr.Row():
create_caption = gr.Checkbox( create_caption = gr.Checkbox(
@ -710,6 +719,8 @@ def finetune_tab():
shuffle_caption, shuffle_caption,
output_name, output_name,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

View File

@ -1,38 +1,3 @@
# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...
# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working
# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode
# v5: fix clip_sample=True for scheduler, add VGG guidance
# v6: refactor to use model util, load VAE without vae folder, support safe tensors
# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size,
# Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction)
# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes,
# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1
# v10: fix app crashes when different image size in prompts
# Copyright 2022 kohya_ss @kohya_ss
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# license of included scripts:
# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.):
# ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
""" """
VGG( VGG(
(features): Sequential( (features): Sequential(
@ -81,11 +46,13 @@ VGG(
) )
""" """
import json
from typing import List, Optional, Union from typing import List, Optional, Union
import glob import glob
import importlib import importlib
import inspect import inspect
import time import time
import zipfile
from diffusers.utils import deprecate from diffusers.utils import deprecate
from diffusers.configuration_utils import FrozenDict from diffusers.configuration_utils import FrozenDict
import argparse import argparse
@ -517,7 +484,7 @@ class PipelineLike():
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
# region xformersとか使う部分独自に書き換えるので関係なし # region xformersとか使う部分独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self): def enable_xformers_memory_efficient_attention(self):
r""" r"""
Enable memory efficient attention as implemented in xformers. Enable memory efficient attention as implemented in xformers.
@ -590,6 +557,7 @@ class PipelineLike():
width: int = 512, width: int = 512,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
negative_scale: float = None,
strength: float = 0.8, strength: float = 0.8,
# num_images_per_prompt: Optional[int] = 1, # num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0, eta: float = 0.0,
@ -708,6 +676,11 @@ class PipelineLike():
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None:
print(f"negative_scale is ignored if guidance scalle <= 1.0")
negative_scale = None
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if negative_prompt is None: if negative_prompt is None:
negative_prompt = [""] * batch_size negative_prompt = [""] * batch_size
@ -729,8 +702,21 @@ class PipelineLike():
**kwargs, **kwargs,
) )
if negative_scale is not None:
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
pipe=self,
prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須
uncond_prompt=[""]*batch_size,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
)
if do_classifier_free_guidance: if do_classifier_free_guidance:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
# CLIP guidanceで使用するembeddingsを取得する # CLIP guidanceで使用するembeddingsを取得する
if self.clip_guidance_scale > 0: if self.clip_guidance_scale > 0:
@ -861,22 +847,28 @@ class PipelineLike():
if accepts_eta: if accepts_eta:
extra_step_kwargs["eta"] = eta extra_step_kwargs["eta"] = eta
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((2, 1, 1, 1)) if do_classifier_free_guidance else latents latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) if negative_scale is None:
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond
noise_pred = noise_pred_uncond + guidance_scale * \
(noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond)
# perform clip guidance # perform clip guidance
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
text_embeddings_for_guidance = (text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings) text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[
1] if do_classifier_free_guidance else text_embeddings)
if self.clip_guidance_scale > 0: if self.clip_guidance_scale > 0:
noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred, noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
@ -1982,26 +1974,50 @@ def main(args):
vgg16_model.to(dtype).to(device) vgg16_model.to(dtype).to(device)
# networkを組み込む # networkを組み込む
if args.network_module is not None: if args.network_module:
# assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません" networks = []
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
print("import network module:", args.network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_module = importlib.import_module(args.network_module) network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs) net_kwargs = {}
if network is None: if args.network_args and i < len(args.network_args):
return network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
print("load network weights from:", args.network_weights) network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
network.load_weights(args.network_weights) if network is None:
return
network.apply_to(text_encoder, unet) if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if args.opt_channels_last: if os.path.splitext(network_weight)[1] == '.safetensors':
network.to(memory_format=torch.channels_last) from safetensors.torch import safe_open
network.to(dtype).to(device) with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight)
network.apply_to(text_encoder, unet)
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
networks.append(network)
else: else:
network = None networks = []
if args.opt_channels_last: if args.opt_channels_last:
print(f"set optimizing: channels last") print(f"set optimizing: channels last")
@ -2010,8 +2026,9 @@ def main(args):
unet.to(memory_format=torch.channels_last) unet.to(memory_format=torch.channels_last)
if clip_model is not None: if clip_model is not None:
clip_model.to(memory_format=torch.channels_last) clip_model.to(memory_format=torch.channels_last)
if network is not None: if networks:
network.to(memory_format=torch.channels_last) for network in networks:
network.to(memory_format=torch.channels_last)
if vgg16_model is not None: if vgg16_model is not None:
vgg16_model.to(memory_format=torch.channels_last) vgg16_model.to(memory_format=torch.channels_last)
@ -2154,12 +2171,12 @@ def main(args):
# 1st stageのバッチを作成して呼び出す # 1st stageのバッチを作成して呼び出す
print("process 1st stage1") print("process 1st stage1")
batch_1st = [] batch_1st = []
for params1, (width, height, steps, scale, strength) in batch: for params1, (width, height, steps, scale, negative_scale, strength) in batch:
width_1st = int(width * args.highres_fix_scale + .5) width_1st = int(width * args.highres_fix_scale + .5)
height_1st = int(height * args.highres_fix_scale + .5) height_1st = int(height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32 width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32 height_1st = height_1st - height_1st % 32
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, strength))) batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
images_1st = process_batch(batch_1st, True, True) images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する # 2nd stageのバッチを作成して以下処理する
@ -2171,7 +2188,8 @@ def main(args):
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2)) batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
batch = batch_2nd batch = batch_2nd
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width, height, steps, scale, strength) = batch[0] (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
height, steps, scale, negative_scale, strength) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = [] prompts = []
@ -2247,7 +2265,7 @@ def main(args):
guide_images = guide_images[0] guide_images = guide_images[0]
# generate # generate
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, strength, latents=start_code, images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
if highres_1st and not args.highres_fix_save_1st: if highres_1st and not args.highres_fix_save_1st:
return images return images
@ -2264,6 +2282,8 @@ def main(args):
metadata.add_text("scale", str(scale)) metadata.add_text("scale", str(scale))
if negative_prompt is not None: if negative_prompt is not None:
metadata.add_text("negative-prompt", negative_prompt) metadata.add_text("negative-prompt", negative_prompt)
if negative_scale is not None:
metadata.add_text("negative-scale", str(negative_scale))
if clip_prompt is not None: if clip_prompt is not None:
metadata.add_text("clip-prompt", clip_prompt) metadata.add_text("clip-prompt", clip_prompt)
@ -2316,6 +2336,7 @@ def main(args):
width = args.W width = args.W
height = args.H height = args.H
scale = args.scale scale = args.scale
negative_scale = args.negative_scale
steps = args.steps steps = args.steps
seeds = None seeds = None
strength = 0.8 if args.strength is None else args.strength strength = 0.8 if args.strength is None else args.strength
@ -2358,6 +2379,15 @@ def main(args):
print(f"scale: {scale}") print(f"scale: {scale}")
continue continue
m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE)
if m: # negative scale
if m.group(1).lower() == 'none':
negative_scale = None
else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
continue
m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE) m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
if m: # strength if m: # strength
strength = float(m.group(1)) strength = float(m.group(1))
@ -2420,8 +2450,9 @@ def main(args):
print("Use previous image as guide image.") print("Use previous image as guide image.")
guide_image = prev_image guide_image = prev_image
# TODO named tupleか何かにする
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
(width, height, steps, scale, strength)) (width, height, steps, scale, negative_scale, strength))
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
process_batch(batch_data, highres_fix) process_batch(batch_data, highres_fix)
batch_data.clear() batch_data.clear()
@ -2481,19 +2512,24 @@ if __name__ == '__main__':
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
parser.add_argument("--seed", type=int, default=None, parser.add_argument("--seed", type=int, default=None,
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed") help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使うプロンプト間の差異の比較用') parser.add_argument("--iter_same_seed", action='store_true',
help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使うプロンプト間の差異の比較用')
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する') parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
parser.add_argument("--diffusers_xformers", action='store_true', parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用するHypernetwork利用不可') help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用するHypernetwork利用不可')
parser.add_argument("--opt_channels_last", action='store_true', parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannels lastを指定し最適化する') help='set channels last option to model / モデルにchannles lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') parser.add_argument("--network_module", type=str, default=None, nargs='*',
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_weights", type=str, default=None, nargs='*',
parser.add_argument("--network_dim", type=int, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
parser.add_argument("--max_embeddings_multiples", type=int, default=None, parser.add_argument("--max_embeddings_multiples", type=int, default=None,
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
@ -2512,6 +2548,8 @@ if __name__ == '__main__':
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数") help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
parser.add_argument("--highres_fix_save_1st", action='store_true', parser.add_argument("--highres_fix_save_1st", action='store_true',
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
parser.add_argument("--negative_scale", type=float, default=None,
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -262,3 +262,30 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
v_parameterization = False v_parameterization = False
return value, v2, v_parameterization return value, v2, v_parameterization
###
### Gradio common GUI section
###
def gradio_advanced_training():
with gr.Row():
max_train_epochs = gr.Textbox(
label='Max train epoch',
placeholder='(Optional) Override number of epoch',
)
max_data_loader_n_workers = gr.Textbox(
label='Max num workers for DataLoader',
placeholder='(Optional) Override number of epoch. Default: 8',
)
return max_train_epochs, max_data_loader_n_workers
def run_cmd_advanced_training(**kwargs):
run_cmd = ''
max_train_epochs = kwargs.get('max_train_epochs', '')
max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '')
if not max_train_epochs == '':
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
if not max_data_loader_n_workers == '':
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
return run_cmd

View File

@ -886,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
info = vae.load_state_dict(converted_vae_checkpoint) info = vae.load_state_dict(converted_vae_checkpoint)
print("loadint vae:", info) print("loading vae:", info)
# convert text_model # convert text_model
if v2: if v2:
@ -1105,12 +1105,12 @@ def load_vae(vae_id, dtype):
if vae_id.endswith(".bin"): if vae_id.endswith(".bin"):
# SD 1.5 VAE on Huggingface # SD 1.5 VAE on Huggingface
vae_sd = torch.load(vae_id, map_location="cpu") converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
converted_vae_checkpoint = vae_sd
else: else:
# StableDiffusion # StableDiffusion
vae_model = torch.load(vae_id, map_location="cpu") vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
vae_sd = vae_model['state_dict'] else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
# vae only or full model # vae only or full model
full_model = False full_model = False
@ -1132,7 +1132,6 @@ def load_vae(vae_id, dtype):
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
return vae return vae
# endregion # endregion

View File

@ -715,7 +715,10 @@ def debug_dataset(train_dataset):
def glob_images(dir, base): def glob_images(dir, base):
img_paths = [] img_paths = []
for ext in IMAGE_EXTENSIONS: for ext in IMAGE_EXTENSIONS:
img_paths.extend(glob.glob(os.path.join(dir, base + ext))) if base == '*':
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext))))
return img_paths return img_paths
# endregion # endregion
@ -744,6 +747,20 @@ def exists(val):
def default(val, d): def default(val, d):
return val if exists(val) else d return val if exists(val) else d
def model_hash(filename):
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
return m.hexdigest()[0:8]
except FileNotFoundError:
return 'NOFILE'
# flash attention forwards and backwards # flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135 # https://arxiv.org/abs/2205.14135
@ -1030,6 +1047,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true", parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする") help="enable gradient checkpointing / grandient checkpointingを有効にする")

View File

@ -19,7 +19,7 @@ from library.common_gui import (
get_saveasfile_path, get_saveasfile_path,
color_aug_changed, color_aug_changed,
save_inference_file, save_inference_file,
set_pretrained_model_name_or_path_input, set_pretrained_model_name_or_path_input, gradio_advanced_training,run_cmd_advanced_training,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@ -81,6 +81,8 @@ def save_configuration(
output_name, output_name,
model_list, model_list,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -165,6 +167,8 @@ def open_configuration(
output_name, output_name,
model_list, model_list,
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -233,6 +237,8 @@ def train_model(
output_name, output_name,
model_list, # Keep this. Yes, it is unused here but required given the common list used model_list, # Keep this. Yes, it is unused here but required given the common list used
max_token_length, max_token_length,
max_train_epochs,
max_data_loader_n_workers,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -410,6 +416,7 @@ def train_model(
run_cmd += f' --output_name="{output_name}"' run_cmd += f' --output_name="{output_name}"'
if (int(max_token_length) > 75): if (int(max_token_length) > 75):
run_cmd += f' --max_token_length={max_token_length}' run_cmd += f' --max_token_length={max_token_length}'
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
print(run_cmd) print(run_cmd)
# Run the command # Run the command
@ -795,6 +802,7 @@ def lora_tab(
], ],
value='75', value='75',
) )
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
with gr.Tab('Tools'): with gr.Tab('Tools'):
gr.Markdown( gr.Markdown(
@ -854,7 +862,9 @@ def lora_tab(
mem_eff_attn, mem_eff_attn,
output_name, output_name,
model_list, model_list,
max_token_length max_token_length,
max_train_epochs,
max_data_loader_n_workers,
] ]
button_open_config.click( button_open_config.click(

View File

@ -135,7 +135,7 @@ def svd(args):
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype) lora_network_o.save_weights(args.save_to, save_dtype, {})
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {args.save_to}")

View File

@ -92,7 +92,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file from safetensors.torch import load_file, safe_open
self.weights_sd = load_file(file) self.weights_sd = load_file(file)
else: else:
self.weights_sd = torch.load(file, map_location='cpu') self.weights_sd = torch.load(file, map_location='cpu')
@ -174,7 +174,10 @@ class LoRANetwork(torch.nn.Module):
def get_trainable_params(self): def get_trainable_params(self):
return self.parameters() return self.parameters()
def save_weights(self, file, dtype): def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict() state_dict = self.state_dict()
if dtype is not None: if dtype is not None:
@ -185,6 +188,6 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file from safetensors.torch import save_file
save_file(state_dict, file) save_file(state_dict, file, metadata)
else: else:
torch.save(state_dict, file) torch.save(state_dict, file)

View File

@ -134,10 +134,15 @@ def train(args):
# dataloaderを準備する # dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler( lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)

View File

@ -126,10 +126,15 @@ def train(args):
# dataloaderを準備する # dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler( lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@ -194,9 +199,62 @@ def train(args):
print(f" num epochs / epoch数: {num_train_epochs}") print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = {
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
"ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
"ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
"ss_resolution": args.resolution,
"ss_clip_skip": args.clip_skip,
"ss_max_token_length": args.max_token_length,
"ss_color_aug": bool(args.color_aug),
"ss_flip_aug": bool(args.flip_aug),
"ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
"ss_max_bucket_reso": args.max_bucket_reso,
"ss_seed": args.seed
}
# uncomment if another network is added
# for key, value in net_kwargs.items():
# metadata["ss_arg_" + key] = value
if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
if args.vae is not None:
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
metadata = {k: str(v) for k, v in metadata.items()}
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0 global_step = 0
@ -208,6 +266,7 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
metadata["ss_epoch"] = str(epoch+1)
network.on_epoch_start(text_encoder, unet) network.on_epoch_start(text_encoder, unet)
@ -296,7 +355,7 @@ def train(args):
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name) ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}") print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype) unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
def remove_old_func(old_epoch_no): def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@ -311,6 +370,8 @@ def train(args):
# end of epoch # end of epoch
metadata["ss_epoch"] = str(num_train_epochs)
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
network = unwrap_model(network) network = unwrap_model(network)
@ -330,7 +391,7 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, ckpt_name) ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}") print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype) network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
print("model saved.") print("model saved.")
@ -341,6 +402,7 @@ if __name__ == '__main__':
train_util.add_dataset_arguments(parser, True, True) train_util.add_dataset_arguments(parser, True, True)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt") help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")