Add support for new arguments:
- max_train_epochs - max_data_loader_n_workers Move some of the codeto common gui library.
This commit is contained in:
parent
43116feda8
commit
6aed2bb402
201
LICENSE.md
Normal file
201
LICENSE.md
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [2022] [kohya-ss]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
@ -99,6 +99,13 @@ Once you have created the LoRA network you can generate images via auto1111 by i
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 2023/01/15 (v20.2.1):
|
||||||
|
- Merging latest code update from kohya
|
||||||
|
- Added `--max_train_epochs` and `--max_data_loader_n_workers` option for each training script.
|
||||||
|
- If you specify the number of training epochs with `--max_train_epochs`, the number of steps is calculated from the number of epochs automatically.
|
||||||
|
- You can set the number of workers for DataLoader with `--max_data_loader_n_workers`, default is 8. The lower number may reduce the main memory usage and the time between epochs, but may cause slower dataloading (training).
|
||||||
|
- Fix loading some VAE or .safetensors as VAE is failed for `--vae` option. Thanks to Fannovel16!
|
||||||
|
- Add negative prompt scaling for `gen_img_diffusers.py` You can set another conditioning scale to the negative prompt with `--negative_scale` option, and `--nl` option for the prompt. Thanks to laksjdjf!
|
||||||
* 2023/01/11 (v20.2.0):
|
* 2023/01/11 (v20.2.0):
|
||||||
- Add support for max token lenght
|
- Add support for max token lenght
|
||||||
* 2023/01/10 (v20.1.1):
|
* 2023/01/10 (v20.1.1):
|
||||||
|
@ -20,6 +20,8 @@ from library.common_gui import (
|
|||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
save_inference_file,
|
save_inference_file,
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input,
|
||||||
|
gradio_advanced_training,
|
||||||
|
run_cmd_advanced_training,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -74,6 +76,8 @@ def save_configuration(
|
|||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -153,6 +157,8 @@ def open_configuration(
|
|||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -216,6 +222,8 @@ def train_model(
|
|||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -372,6 +380,11 @@ def train_model(
|
|||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
if (int(max_token_length) > 75):
|
if (int(max_token_length) > 75):
|
||||||
run_cmd += f' --max_token_length={max_token_length}'
|
run_cmd += f' --max_token_length={max_token_length}'
|
||||||
|
if not max_train_epochs == '':
|
||||||
|
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
||||||
|
if not max_data_loader_n_workers == '':
|
||||||
|
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
||||||
|
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -708,6 +721,16 @@ def dreambooth_tab(
|
|||||||
],
|
],
|
||||||
value='75',
|
value='75',
|
||||||
)
|
)
|
||||||
|
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||||
|
# with gr.Row():
|
||||||
|
# max_train_epochs = gr.Textbox(
|
||||||
|
# label='Max train epoch',
|
||||||
|
# placeholder='(Optional) Override number of epoch',
|
||||||
|
# )
|
||||||
|
# max_data_loader_n_workers = gr.Textbox(
|
||||||
|
# label='Max num workers for DataLoader',
|
||||||
|
# placeholder='(Optional) Override number of epoch. Default: 8',
|
||||||
|
# )
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
'This section provide Dreambooth tools to help setup your dataset...'
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
@ -760,6 +783,8 @@ def dreambooth_tab(
|
|||||||
vae,
|
vae,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
688
fine_tune.py
688
fine_tune.py
@ -16,456 +16,326 @@ import library.train_util as train_util
|
|||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
return examples[0]
|
return examples[0]
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
|
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed) # 乱数系列を初期化する
|
set_seed(args.seed) # 乱数系列を初期化する
|
||||||
|
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
train_dataset = train_util.FineTuningDataset(
|
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||||
args.in_json,
|
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||||
args.train_batch_size,
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||||
args.train_data_dir,
|
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||||
tokenizer,
|
args.dataset_repeats, args.debug_dataset)
|
||||||
args.max_token_length,
|
train_dataset.make_buckets()
|
||||||
args.shuffle_caption,
|
|
||||||
args.keep_tokens,
|
|
||||||
args.resolution,
|
|
||||||
args.enable_bucket,
|
|
||||||
args.min_bucket_reso,
|
|
||||||
args.max_bucket_reso,
|
|
||||||
args.flip_aug,
|
|
||||||
args.color_aug,
|
|
||||||
args.face_crop_aug_range,
|
|
||||||
args.random_crop,
|
|
||||||
args.dataset_repeats,
|
|
||||||
args.debug_dataset,
|
|
||||||
)
|
|
||||||
train_dataset.make_buckets()
|
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
train_util.debug_dataset(train_dataset)
|
train_util.debug_dataset(train_dataset)
|
||||||
return
|
return
|
||||||
if len(train_dataset) == 0:
|
if len(train_dataset) == 0:
|
||||||
print(
|
print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。")
|
||||||
'No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。'
|
return
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print('prepare accelerator')
|
print("prepare accelerator")
|
||||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
(
|
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
||||||
text_encoder,
|
|
||||||
vae,
|
|
||||||
unet,
|
|
||||||
load_stable_diffusion_format,
|
|
||||||
) = train_util.load_target_model(args, weight_dtype)
|
|
||||||
|
|
||||||
# verify load/save model formats
|
# verify load/save model formats
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
||||||
src_diffusers_model_path = None
|
src_diffusers_model_path = None
|
||||||
else:
|
else:
|
||||||
src_stable_diffusion_ckpt = None
|
src_stable_diffusion_ckpt = None
|
||||||
src_diffusers_model_path = args.pretrained_model_name_or_path
|
src_diffusers_model_path = args.pretrained_model_name_or_path
|
||||||
|
|
||||||
if args.save_model_as is None:
|
if args.save_model_as is None:
|
||||||
save_stable_diffusion_format = load_stable_diffusion_format
|
save_stable_diffusion_format = load_stable_diffusion_format
|
||||||
use_safetensors = args.use_safetensors
|
use_safetensors = args.use_safetensors
|
||||||
else:
|
else:
|
||||||
save_stable_diffusion_format = (
|
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors'
|
||||||
args.save_model_as.lower() == 'ckpt'
|
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||||
or args.save_model_as.lower() == 'safetensors'
|
|
||||||
)
|
|
||||||
use_safetensors = args.use_safetensors or (
|
|
||||||
'safetensors' in args.save_model_as.lower()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Diffusers版のxformers使用フラグを設定する関数
|
# Diffusers版のxformers使用フラグを設定する関数
|
||||||
def set_diffusers_xformers_flag(model, valid):
|
def set_diffusers_xformers_flag(model, valid):
|
||||||
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
||||||
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
||||||
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
||||||
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
||||||
|
|
||||||
# Recursively walk through all the children.
|
# Recursively walk through all the children.
|
||||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||||
# gets the message
|
# gets the message
|
||||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||||
if hasattr(module, 'set_use_memory_efficient_attention_xformers'):
|
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||||
module.set_use_memory_efficient_attention_xformers(valid)
|
module.set_use_memory_efficient_attention_xformers(valid)
|
||||||
|
|
||||||
for child in module.children():
|
for child in module.children():
|
||||||
fn_recursive_set_mem_eff(child)
|
fn_recursive_set_mem_eff(child)
|
||||||
|
|
||||||
fn_recursive_set_mem_eff(model)
|
fn_recursive_set_mem_eff(model)
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
print('Use xformers by Diffusers')
|
print("Use xformers by Diffusers")
|
||||||
set_diffusers_xformers_flag(unet, True)
|
set_diffusers_xformers_flag(unet, True)
|
||||||
else:
|
else:
|
||||||
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
||||||
print("Disable Diffusers' xformers")
|
print("Disable Diffusers' xformers")
|
||||||
set_diffusers_xformers_flag(unet, False)
|
set_diffusers_xformers_flag(unet, False)
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset.cache_latents(vae)
|
train_dataset.cache_latents(vae)
|
||||||
vae.to('cpu')
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
training_models = []
|
training_models = []
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
|
training_models.append(unet)
|
||||||
|
|
||||||
|
if args.train_text_encoder:
|
||||||
|
print("enable text encoder training")
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
training_models.append(unet)
|
training_models.append(text_encoder)
|
||||||
|
else:
|
||||||
if args.train_text_encoder:
|
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||||
print('enable text encoder training')
|
text_encoder.requires_grad_(False) # text encoderは学習しない
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
text_encoder.gradient_checkpointing_enable()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
training_models.append(text_encoder)
|
text_encoder.train() # required for gradient_checkpointing
|
||||||
else:
|
else:
|
||||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
text_encoder.eval()
|
||||||
text_encoder.requires_grad_(False) # text encoderは学習しない
|
|
||||||
if args.gradient_checkpointing:
|
if not cache_latents:
|
||||||
text_encoder.gradient_checkpointing_enable()
|
vae.requires_grad_(False)
|
||||||
text_encoder.train() # required for gradient_checkpointing
|
vae.eval()
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
for m in training_models:
|
||||||
|
m.requires_grad_(True)
|
||||||
|
params = []
|
||||||
|
for m in training_models:
|
||||||
|
params.extend(m.parameters())
|
||||||
|
params_to_optimize = params
|
||||||
|
|
||||||
|
# 学習に必要なクラスを準備する
|
||||||
|
print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
|
# 8-bit Adamを使う
|
||||||
|
if args.use_8bit_adam:
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
||||||
|
print("use 8-bit Adam optimizer")
|
||||||
|
optimizer_class = bnb.optim.AdamW8bit
|
||||||
|
else:
|
||||||
|
optimizer_class = torch.optim.AdamW
|
||||||
|
|
||||||
|
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
||||||
|
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
||||||
|
|
||||||
|
# dataloaderを準備する
|
||||||
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||||
|
|
||||||
|
# 学習ステップ数を計算する
|
||||||
|
if args.max_train_epochs is not None:
|
||||||
|
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||||
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
# lr schedulerを用意する
|
||||||
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
|
if args.full_fp16:
|
||||||
|
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
|
print("enable full fp16 training.")
|
||||||
|
unet.to(weight_dtype)
|
||||||
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
|
if args.train_text_encoder:
|
||||||
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
else:
|
||||||
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
|
if args.full_fp16:
|
||||||
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
|
# resumeする
|
||||||
|
if args.resume is not None:
|
||||||
|
print(f"resume training from state: {args.resume}")
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
|
||||||
|
# epoch数を計算する
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
# 学習する
|
||||||
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
print("running training / 学習開始")
|
||||||
|
print(f" num examples / サンプル数: {train_dataset.num_train_images}")
|
||||||
|
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
|
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
|
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||||
|
num_train_timesteps=1000, clip_sample=False)
|
||||||
|
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
accelerator.init_trackers("finetuning")
|
||||||
|
|
||||||
|
for epoch in range(num_train_epochs):
|
||||||
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
for m in training_models:
|
||||||
|
m.train()
|
||||||
|
|
||||||
|
loss_total = 0
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
|
with torch.no_grad():
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
else:
|
||||||
|
# latentに変換
|
||||||
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
|
latents = latents * 0.18215
|
||||||
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(args.train_text_encoder):
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
input_ids = batch["input_ids"].to(accelerator.device)
|
||||||
|
encoder_hidden_states = train_util.get_hidden_states(
|
||||||
|
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
|
||||||
|
|
||||||
|
# Sample noise that we'll add to the latents
|
||||||
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
|
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
|
# (this is the forward diffusion process)
|
||||||
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
|
if args.v_parameterization:
|
||||||
|
# v-parameterization training
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
else:
|
else:
|
||||||
text_encoder.eval()
|
target = noise
|
||||||
|
|
||||||
if not cache_latents:
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||||
vae.requires_grad_(False)
|
|
||||||
vae.eval()
|
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
for m in training_models:
|
accelerator.backward(loss)
|
||||||
m.requires_grad_(True)
|
if accelerator.sync_gradients:
|
||||||
params = []
|
params_to_clip = []
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
params.extend(m.parameters())
|
params_to_clip.extend(m.parameters())
|
||||||
params_to_optimize = params
|
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
optimizer.step()
|
||||||
print('prepare optimizer, data loader etc.')
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# 8-bit Adamを使う
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
if args.use_8bit_adam:
|
if accelerator.sync_gradients:
|
||||||
try:
|
progress_bar.update(1)
|
||||||
import bitsandbytes as bnb
|
global_step += 1
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
'No bitsand bytes / bitsandbytesがインストールされていないようです'
|
|
||||||
)
|
|
||||||
print('use 8-bit Adam optimizer')
|
|
||||||
optimizer_class = bnb.optim.AdamW8bit
|
|
||||||
else:
|
|
||||||
optimizer_class = torch.optim.AdamW
|
|
||||||
|
|
||||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||||
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
# dataloaderを準備する
|
loss_total += current_loss
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
avr_loss = loss_total / (step+1)
|
||||||
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
progress_bar.set_postfix(**logs)
|
||||||
train_dataset,
|
|
||||||
batch_size=1,
|
|
||||||
shuffle=False,
|
|
||||||
collate_fn=collate_fn,
|
|
||||||
num_workers=n_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# lr schedulerを用意する
|
if global_step >= args.max_train_steps:
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
break
|
||||||
args.lr_scheduler,
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=args.lr_warmup_steps,
|
|
||||||
num_training_steps=args.max_train_steps
|
|
||||||
* args.gradient_accumulation_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
if args.logging_dir is not None:
|
||||||
if args.full_fp16:
|
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
||||||
assert (
|
accelerator.log(logs, step=epoch+1)
|
||||||
args.mixed_precision == 'fp16'
|
|
||||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
|
||||||
print('enable full fp16 training.')
|
|
||||||
unet.to(weight_dtype)
|
|
||||||
text_encoder.to(weight_dtype)
|
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
accelerator.wait_for_everyone()
|
||||||
if args.train_text_encoder:
|
|
||||||
(
|
|
||||||
unet,
|
|
||||||
text_encoder,
|
|
||||||
optimizer,
|
|
||||||
train_dataloader,
|
|
||||||
lr_scheduler,
|
|
||||||
) = accelerator.prepare(
|
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
||||||
unet, optimizer, train_dataloader, lr_scheduler
|
|
||||||
)
|
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
if args.save_every_n_epochs is not None:
|
||||||
if args.full_fp16:
|
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
|
||||||
|
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
|
||||||
|
|
||||||
# resumeする
|
is_main_process = accelerator.is_main_process
|
||||||
if args.resume is not None:
|
if is_main_process:
|
||||||
print(f'resume training from state: {args.resume}')
|
unet = unwrap_model(unet)
|
||||||
accelerator.load_state(args.resume)
|
text_encoder = unwrap_model(text_encoder)
|
||||||
|
|
||||||
# epoch数を計算する
|
accelerator.end_training()
|
||||||
num_update_steps_per_epoch = math.ceil(
|
|
||||||
len(train_dataloader) / args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
num_train_epochs = math.ceil(
|
|
||||||
args.max_train_steps / num_update_steps_per_epoch
|
|
||||||
)
|
|
||||||
|
|
||||||
# 学習する
|
if args.save_state:
|
||||||
total_batch_size = (
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
args.train_batch_size
|
|
||||||
* accelerator.num_processes
|
|
||||||
* args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
print('running training / 学習開始')
|
|
||||||
print(f' num examples / サンプル数: {train_dataset.num_train_images}')
|
|
||||||
print(f' num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}')
|
|
||||||
print(f' num epochs / epoch数: {num_train_epochs}')
|
|
||||||
print(f' batch size per device / バッチサイズ: {args.train_batch_size}')
|
|
||||||
print(
|
|
||||||
f' total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}'
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f' gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}'
|
|
||||||
)
|
|
||||||
print(f' total optimization steps / 学習ステップ数: {args.max_train_steps}')
|
|
||||||
|
|
||||||
progress_bar = tqdm(
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
range(args.max_train_steps),
|
|
||||||
smoothing=0,
|
|
||||||
disable=not accelerator.is_local_main_process,
|
|
||||||
desc='steps',
|
|
||||||
)
|
|
||||||
global_step = 0
|
|
||||||
|
|
||||||
noise_scheduler = DDPMScheduler(
|
if is_main_process:
|
||||||
beta_start=0.00085,
|
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||||
beta_end=0.012,
|
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
|
||||||
beta_schedule='scaled_linear',
|
save_dtype, epoch, global_step, text_encoder, unet, vae)
|
||||||
num_train_timesteps=1000,
|
print("model saved.")
|
||||||
clip_sample=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
accelerator.init_trackers('finetuning')
|
|
||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
|
||||||
print(f'epoch {epoch+1}/{num_train_epochs}')
|
|
||||||
for m in training_models:
|
|
||||||
m.train()
|
|
||||||
|
|
||||||
loss_total = 0
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
|
||||||
with accelerator.accumulate(
|
|
||||||
training_models[0]
|
|
||||||
): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
|
||||||
with torch.no_grad():
|
|
||||||
if 'latents' in batch and batch['latents'] is not None:
|
|
||||||
latents = batch['latents'].to(accelerator.device)
|
|
||||||
else:
|
|
||||||
# latentに変換
|
|
||||||
latents = vae.encode(
|
|
||||||
batch['images'].to(dtype=weight_dtype)
|
|
||||||
).latent_dist.sample()
|
|
||||||
latents = latents * 0.18215
|
|
||||||
b_size = latents.shape[0]
|
|
||||||
|
|
||||||
with torch.set_grad_enabled(args.train_text_encoder):
|
|
||||||
# Get the text embedding for conditioning
|
|
||||||
input_ids = batch['input_ids'].to(accelerator.device)
|
|
||||||
encoder_hidden_states = train_util.get_hidden_states(
|
|
||||||
args,
|
|
||||||
input_ids,
|
|
||||||
tokenizer,
|
|
||||||
text_encoder,
|
|
||||||
None if not args.full_fp16 else weight_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
|
||||||
timesteps = torch.randint(
|
|
||||||
0,
|
|
||||||
noise_scheduler.config.num_train_timesteps,
|
|
||||||
(b_size,),
|
|
||||||
device=latents.device,
|
|
||||||
)
|
|
||||||
timesteps = timesteps.long()
|
|
||||||
|
|
||||||
# Add noise to the latents according to the noise magnitude at each timestep
|
|
||||||
# (this is the forward diffusion process)
|
|
||||||
noisy_latents = noise_scheduler.add_noise(
|
|
||||||
latents, noise, timesteps
|
|
||||||
)
|
|
||||||
|
|
||||||
# Predict the noise residual
|
|
||||||
noise_pred = unet(
|
|
||||||
noisy_latents, timesteps, encoder_hidden_states
|
|
||||||
).sample
|
|
||||||
|
|
||||||
if args.v_parameterization:
|
|
||||||
# v-parameterization training
|
|
||||||
target = noise_scheduler.get_velocity(
|
|
||||||
latents, noise, timesteps
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
target = noise
|
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(
|
|
||||||
noise_pred.float(), target.float(), reduction='mean'
|
|
||||||
)
|
|
||||||
|
|
||||||
accelerator.backward(loss)
|
|
||||||
if accelerator.sync_gradients:
|
|
||||||
params_to_clip = []
|
|
||||||
for m in training_models:
|
|
||||||
params_to_clip.extend(m.parameters())
|
|
||||||
accelerator.clip_grad_norm_(
|
|
||||||
params_to_clip, 1.0
|
|
||||||
) # args.max_grad_norm)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
lr_scheduler.step()
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
||||||
if accelerator.sync_gradients:
|
|
||||||
progress_bar.update(1)
|
|
||||||
global_step += 1
|
|
||||||
|
|
||||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
|
||||||
if args.logging_dir is not None:
|
|
||||||
logs = {
|
|
||||||
'loss': current_loss,
|
|
||||||
'lr': lr_scheduler.get_last_lr()[0],
|
|
||||||
}
|
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / (step + 1)
|
|
||||||
logs = {'loss': avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
|
||||||
break
|
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
|
||||||
logs = {'epoch_loss': loss_total / len(train_dataloader)}
|
|
||||||
accelerator.log(logs, step=epoch + 1)
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
if args.save_every_n_epochs is not None:
|
|
||||||
src_path = (
|
|
||||||
src_stable_diffusion_ckpt
|
|
||||||
if save_stable_diffusion_format
|
|
||||||
else src_diffusers_model_path
|
|
||||||
)
|
|
||||||
train_util.save_sd_model_on_epoch_end(
|
|
||||||
args,
|
|
||||||
accelerator,
|
|
||||||
src_path,
|
|
||||||
save_stable_diffusion_format,
|
|
||||||
use_safetensors,
|
|
||||||
save_dtype,
|
|
||||||
epoch,
|
|
||||||
num_train_epochs,
|
|
||||||
global_step,
|
|
||||||
unwrap_model(text_encoder),
|
|
||||||
unwrap_model(unet),
|
|
||||||
vae,
|
|
||||||
)
|
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
|
||||||
if is_main_process:
|
|
||||||
unet = unwrap_model(unet)
|
|
||||||
text_encoder = unwrap_model(text_encoder)
|
|
||||||
|
|
||||||
accelerator.end_training()
|
|
||||||
|
|
||||||
if args.save_state:
|
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
|
||||||
|
|
||||||
del accelerator # この後メモリを使うのでこれは消す
|
|
||||||
|
|
||||||
if is_main_process:
|
|
||||||
src_path = (
|
|
||||||
src_stable_diffusion_ckpt
|
|
||||||
if save_stable_diffusion_format
|
|
||||||
else src_diffusers_model_path
|
|
||||||
)
|
|
||||||
train_util.save_sd_model_on_train_end(
|
|
||||||
args,
|
|
||||||
src_path,
|
|
||||||
save_stable_diffusion_format,
|
|
||||||
use_safetensors,
|
|
||||||
save_dtype,
|
|
||||||
epoch,
|
|
||||||
global_step,
|
|
||||||
text_encoder,
|
|
||||||
unet,
|
|
||||||
vae,
|
|
||||||
)
|
|
||||||
print('model saved.')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, False, True)
|
train_util.add_dataset_arguments(parser, False, True)
|
||||||
train_util.add_training_arguments(parser, False)
|
train_util.add_training_arguments(parser, False)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||||
'--diffusers_xformers',
|
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
||||||
action='store_true',
|
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||||
help='use xformers by diffusers / Diffusersでxformersを使用する',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--train_text_encoder',
|
|
||||||
action='store_true',
|
|
||||||
help='train text encoder / text encoderも学習する',
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
@ -108,4 +108,4 @@ if __name__ == '__main__':
|
|||||||
if args.caption_extention is not None:
|
if args.caption_extention is not None:
|
||||||
args.caption_extension = args.caption_extention
|
args.caption_extension = args.caption_extention
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -1,789 +0,0 @@
|
|||||||
import gradio as gr
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import pathlib
|
|
||||||
import shutil
|
|
||||||
import argparse
|
|
||||||
from library.common_gui import (
|
|
||||||
get_folder_path,
|
|
||||||
get_file_path,
|
|
||||||
get_any_file_path,
|
|
||||||
get_saveasfile_path,
|
|
||||||
)
|
|
||||||
from library.utilities import utilities_tab
|
|
||||||
|
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
|
||||||
save_style_symbol = '\U0001f4be' # 💾
|
|
||||||
document_symbol = '\U0001F4C4' # 📄
|
|
||||||
|
|
||||||
|
|
||||||
def save_configuration(
|
|
||||||
save_as,
|
|
||||||
file_path,
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
v2,
|
|
||||||
v_parameterization,
|
|
||||||
train_dir,
|
|
||||||
image_folder,
|
|
||||||
output_dir,
|
|
||||||
logging_dir,
|
|
||||||
max_resolution,
|
|
||||||
min_bucket_reso,
|
|
||||||
max_bucket_reso,
|
|
||||||
batch_size,
|
|
||||||
flip_aug,
|
|
||||||
caption_metadata_filename,
|
|
||||||
latent_metadata_filename,
|
|
||||||
full_path,
|
|
||||||
learning_rate,
|
|
||||||
lr_scheduler,
|
|
||||||
lr_warmup,
|
|
||||||
dataset_repeats,
|
|
||||||
train_batch_size,
|
|
||||||
epoch,
|
|
||||||
save_every_n_epochs,
|
|
||||||
mixed_precision,
|
|
||||||
save_precision,
|
|
||||||
seed,
|
|
||||||
num_cpu_threads_per_process,
|
|
||||||
train_text_encoder,
|
|
||||||
create_caption,
|
|
||||||
create_buckets,
|
|
||||||
save_model_as,
|
|
||||||
caption_extension,
|
|
||||||
use_8bit_adam,
|
|
||||||
xformers,
|
|
||||||
clip_skip,
|
|
||||||
):
|
|
||||||
original_file_path = file_path
|
|
||||||
|
|
||||||
save_as_bool = True if save_as.get('label') == 'True' else False
|
|
||||||
|
|
||||||
if save_as_bool:
|
|
||||||
print('Save as...')
|
|
||||||
file_path = get_saveasfile_path(file_path)
|
|
||||||
else:
|
|
||||||
print('Save...')
|
|
||||||
if file_path == None or file_path == '':
|
|
||||||
file_path = get_saveasfile_path(file_path)
|
|
||||||
|
|
||||||
# print(file_path)
|
|
||||||
|
|
||||||
if file_path == None:
|
|
||||||
return original_file_path
|
|
||||||
|
|
||||||
# Return the values of the variables as a dictionary
|
|
||||||
variables = {
|
|
||||||
'pretrained_model_name_or_path': pretrained_model_name_or_path,
|
|
||||||
'v2': v2,
|
|
||||||
'v_parameterization': v_parameterization,
|
|
||||||
'train_dir': train_dir,
|
|
||||||
'image_folder': image_folder,
|
|
||||||
'output_dir': output_dir,
|
|
||||||
'logging_dir': logging_dir,
|
|
||||||
'max_resolution': max_resolution,
|
|
||||||
'min_bucket_reso': min_bucket_reso,
|
|
||||||
'max_bucket_reso': max_bucket_reso,
|
|
||||||
'batch_size': batch_size,
|
|
||||||
'flip_aug': flip_aug,
|
|
||||||
'caption_metadata_filename': caption_metadata_filename,
|
|
||||||
'latent_metadata_filename': latent_metadata_filename,
|
|
||||||
'full_path': full_path,
|
|
||||||
'learning_rate': learning_rate,
|
|
||||||
'lr_scheduler': lr_scheduler,
|
|
||||||
'lr_warmup': lr_warmup,
|
|
||||||
'dataset_repeats': dataset_repeats,
|
|
||||||
'train_batch_size': train_batch_size,
|
|
||||||
'epoch': epoch,
|
|
||||||
'save_every_n_epochs': save_every_n_epochs,
|
|
||||||
'mixed_precision': mixed_precision,
|
|
||||||
'save_precision': save_precision,
|
|
||||||
'seed': seed,
|
|
||||||
'num_cpu_threads_per_process': num_cpu_threads_per_process,
|
|
||||||
'train_text_encoder': train_text_encoder,
|
|
||||||
'create_buckets': create_buckets,
|
|
||||||
'create_caption': create_caption,
|
|
||||||
'save_model_as': save_model_as,
|
|
||||||
'caption_extension': caption_extension,
|
|
||||||
'use_8bit_adam': use_8bit_adam,
|
|
||||||
'xformers': xformers,
|
|
||||||
'clip_skip': clip_skip,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Save the data to the selected file
|
|
||||||
with open(file_path, 'w') as file:
|
|
||||||
json.dump(variables, file)
|
|
||||||
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
|
|
||||||
def open_config_file(
|
|
||||||
file_path,
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
v2,
|
|
||||||
v_parameterization,
|
|
||||||
train_dir,
|
|
||||||
image_folder,
|
|
||||||
output_dir,
|
|
||||||
logging_dir,
|
|
||||||
max_resolution,
|
|
||||||
min_bucket_reso,
|
|
||||||
max_bucket_reso,
|
|
||||||
batch_size,
|
|
||||||
flip_aug,
|
|
||||||
caption_metadata_filename,
|
|
||||||
latent_metadata_filename,
|
|
||||||
full_path,
|
|
||||||
learning_rate,
|
|
||||||
lr_scheduler,
|
|
||||||
lr_warmup,
|
|
||||||
dataset_repeats,
|
|
||||||
train_batch_size,
|
|
||||||
epoch,
|
|
||||||
save_every_n_epochs,
|
|
||||||
mixed_precision,
|
|
||||||
save_precision,
|
|
||||||
seed,
|
|
||||||
num_cpu_threads_per_process,
|
|
||||||
train_text_encoder,
|
|
||||||
create_caption,
|
|
||||||
create_buckets,
|
|
||||||
save_model_as,
|
|
||||||
caption_extension,
|
|
||||||
use_8bit_adam,
|
|
||||||
xformers,
|
|
||||||
clip_skip,
|
|
||||||
):
|
|
||||||
original_file_path = file_path
|
|
||||||
file_path = get_file_path(file_path)
|
|
||||||
|
|
||||||
if file_path != '' and file_path != None:
|
|
||||||
print(file_path)
|
|
||||||
# load variables from JSON file
|
|
||||||
with open(file_path, 'r') as f:
|
|
||||||
my_data = json.load(f)
|
|
||||||
else:
|
|
||||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
|
||||||
my_data = {}
|
|
||||||
|
|
||||||
# Return the values of the variables as a dictionary
|
|
||||||
return (
|
|
||||||
file_path,
|
|
||||||
my_data.get(
|
|
||||||
'pretrained_model_name_or_path', pretrained_model_name_or_path
|
|
||||||
),
|
|
||||||
my_data.get('v2', v2),
|
|
||||||
my_data.get('v_parameterization', v_parameterization),
|
|
||||||
my_data.get('train_dir', train_dir),
|
|
||||||
my_data.get('image_folder', image_folder),
|
|
||||||
my_data.get('output_dir', output_dir),
|
|
||||||
my_data.get('logging_dir', logging_dir),
|
|
||||||
my_data.get('max_resolution', max_resolution),
|
|
||||||
my_data.get('min_bucket_reso', min_bucket_reso),
|
|
||||||
my_data.get('max_bucket_reso', max_bucket_reso),
|
|
||||||
my_data.get('batch_size', batch_size),
|
|
||||||
my_data.get('flip_aug', flip_aug),
|
|
||||||
my_data.get('caption_metadata_filename', caption_metadata_filename),
|
|
||||||
my_data.get('latent_metadata_filename', latent_metadata_filename),
|
|
||||||
my_data.get('full_path', full_path),
|
|
||||||
my_data.get('learning_rate', learning_rate),
|
|
||||||
my_data.get('lr_scheduler', lr_scheduler),
|
|
||||||
my_data.get('lr_warmup', lr_warmup),
|
|
||||||
my_data.get('dataset_repeats', dataset_repeats),
|
|
||||||
my_data.get('train_batch_size', train_batch_size),
|
|
||||||
my_data.get('epoch', epoch),
|
|
||||||
my_data.get('save_every_n_epochs', save_every_n_epochs),
|
|
||||||
my_data.get('mixed_precision', mixed_precision),
|
|
||||||
my_data.get('save_precision', save_precision),
|
|
||||||
my_data.get('seed', seed),
|
|
||||||
my_data.get(
|
|
||||||
'num_cpu_threads_per_process', num_cpu_threads_per_process
|
|
||||||
),
|
|
||||||
my_data.get('train_text_encoder', train_text_encoder),
|
|
||||||
my_data.get('create_buckets', create_buckets),
|
|
||||||
my_data.get('create_caption', create_caption),
|
|
||||||
my_data.get('save_model_as', save_model_as),
|
|
||||||
my_data.get('caption_extension', caption_extension),
|
|
||||||
my_data.get('use_8bit_adam', use_8bit_adam),
|
|
||||||
my_data.get('xformers', xformers),
|
|
||||||
my_data.get('clip_skip', clip_skip),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def train_model(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
v2,
|
|
||||||
v_parameterization,
|
|
||||||
train_dir,
|
|
||||||
image_folder,
|
|
||||||
output_dir,
|
|
||||||
logging_dir,
|
|
||||||
max_resolution,
|
|
||||||
min_bucket_reso,
|
|
||||||
max_bucket_reso,
|
|
||||||
batch_size,
|
|
||||||
flip_aug,
|
|
||||||
caption_metadata_filename,
|
|
||||||
latent_metadata_filename,
|
|
||||||
full_path,
|
|
||||||
learning_rate,
|
|
||||||
lr_scheduler,
|
|
||||||
lr_warmup,
|
|
||||||
dataset_repeats,
|
|
||||||
train_batch_size,
|
|
||||||
epoch,
|
|
||||||
save_every_n_epochs,
|
|
||||||
mixed_precision,
|
|
||||||
save_precision,
|
|
||||||
seed,
|
|
||||||
num_cpu_threads_per_process,
|
|
||||||
train_text_encoder,
|
|
||||||
generate_caption_database,
|
|
||||||
generate_image_buckets,
|
|
||||||
save_model_as,
|
|
||||||
caption_extension,
|
|
||||||
use_8bit_adam,
|
|
||||||
xformers,
|
|
||||||
clip_skip,
|
|
||||||
):
|
|
||||||
def save_inference_file(output_dir, v2, v_parameterization):
|
|
||||||
# Copy inference model for v2 if required
|
|
||||||
if v2 and v_parameterization:
|
|
||||||
print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml')
|
|
||||||
shutil.copy(
|
|
||||||
f'./v2_inference/v2-inference-v.yaml',
|
|
||||||
f'{output_dir}/last.yaml',
|
|
||||||
)
|
|
||||||
elif v2:
|
|
||||||
print(f'Saving v2-inference.yaml as {output_dir}/last.yaml')
|
|
||||||
shutil.copy(
|
|
||||||
f'./v2_inference/v2-inference.yaml',
|
|
||||||
f'{output_dir}/last.yaml',
|
|
||||||
)
|
|
||||||
|
|
||||||
# create caption json file
|
|
||||||
if generate_caption_database:
|
|
||||||
if not os.path.exists(train_dir):
|
|
||||||
os.mkdir(train_dir)
|
|
||||||
|
|
||||||
for root, dirs, files in os.walk(image_folder):
|
|
||||||
for dir in dirs:
|
|
||||||
print(os.path.join(root, dir))
|
|
||||||
|
|
||||||
run_cmd = (
|
|
||||||
f'./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py'
|
|
||||||
)
|
|
||||||
if caption_extension == '':
|
|
||||||
run_cmd += f' --caption_extension=".txt"'
|
|
||||||
else:
|
|
||||||
run_cmd += f' --caption_extension={caption_extension}'
|
|
||||||
run_cmd += f' "{os.path.join(root, dir)}"'
|
|
||||||
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
|
|
||||||
if full_path:
|
|
||||||
run_cmd += f' --full_path'
|
|
||||||
|
|
||||||
print(run_cmd)
|
|
||||||
|
|
||||||
# Run the command
|
|
||||||
subprocess.run(run_cmd)
|
|
||||||
|
|
||||||
# create images buckets
|
|
||||||
if generate_image_buckets:
|
|
||||||
run_cmd = (
|
|
||||||
f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
|
|
||||||
)
|
|
||||||
run_cmd += f' "crap"'
|
|
||||||
run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
|
|
||||||
run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
|
|
||||||
run_cmd += f' "{pretrained_model_name_or_path}"'
|
|
||||||
run_cmd += f' --batch_size={batch_size}'
|
|
||||||
run_cmd += f' --max_resolution={max_resolution}'
|
|
||||||
run_cmd += f' --min_bucket_reso={min_bucket_reso}'
|
|
||||||
run_cmd += f' --max_bucket_reso={max_bucket_reso}'
|
|
||||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
|
||||||
if flip_aug:
|
|
||||||
run_cmd += f' --flip_aug'
|
|
||||||
if full_path:
|
|
||||||
run_cmd += f' --full_path'
|
|
||||||
|
|
||||||
print(run_cmd)
|
|
||||||
|
|
||||||
# Run the command
|
|
||||||
subprocess.run(run_cmd)
|
|
||||||
|
|
||||||
image_num = 0
|
|
||||||
for root, dirs, files in os.walk(image_folder):
|
|
||||||
for dir in dirs:
|
|
||||||
image_num += len(
|
|
||||||
[f for f in os.listdir(os.path.join(root, dir)) if f.endswith('.npz')]
|
|
||||||
)
|
|
||||||
print(f'image_num = {image_num}')
|
|
||||||
|
|
||||||
repeats = int(image_num) * int(dataset_repeats)
|
|
||||||
print(f'repeats = {str(repeats)}')
|
|
||||||
|
|
||||||
# calculate max_train_steps
|
|
||||||
max_train_steps = int(
|
|
||||||
math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Divide by two because flip augmentation create two copied of the source images
|
|
||||||
if flip_aug:
|
|
||||||
max_train_steps = int(math.ceil(float(max_train_steps) / 2))
|
|
||||||
|
|
||||||
print(f'max_train_steps = {max_train_steps}')
|
|
||||||
|
|
||||||
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
|
|
||||||
print(f'lr_warmup_steps = {lr_warmup_steps}')
|
|
||||||
|
|
||||||
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
|
|
||||||
if v2:
|
|
||||||
run_cmd += ' --v2'
|
|
||||||
if v_parameterization:
|
|
||||||
run_cmd += ' --v_parameterization'
|
|
||||||
if train_text_encoder:
|
|
||||||
run_cmd += ' --train_text_encoder'
|
|
||||||
if use_8bit_adam:
|
|
||||||
run_cmd += f' --use_8bit_adam'
|
|
||||||
if xformers:
|
|
||||||
run_cmd += f' --xformers'
|
|
||||||
run_cmd += (
|
|
||||||
f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
|
|
||||||
)
|
|
||||||
run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
|
|
||||||
run_cmd += f' --train_data_dir="{image_folder}"'
|
|
||||||
run_cmd += f' --output_dir="{output_dir}"'
|
|
||||||
if not logging_dir == '':
|
|
||||||
run_cmd += f' --logging_dir="{logging_dir}"'
|
|
||||||
run_cmd += f' --train_batch_size={train_batch_size}'
|
|
||||||
run_cmd += f' --dataset_repeats={dataset_repeats}'
|
|
||||||
run_cmd += f' --learning_rate={learning_rate}'
|
|
||||||
run_cmd += f' --lr_scheduler={lr_scheduler}'
|
|
||||||
run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
|
|
||||||
run_cmd += f' --max_train_steps={max_train_steps}'
|
|
||||||
run_cmd += f' --mixed_precision={mixed_precision}'
|
|
||||||
run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
|
|
||||||
run_cmd += f' --seed={seed}'
|
|
||||||
run_cmd += f' --save_precision={save_precision}'
|
|
||||||
if not save_model_as == 'same as source model':
|
|
||||||
run_cmd += f' --save_model_as={save_model_as}'
|
|
||||||
if int(clip_skip) > 1:
|
|
||||||
run_cmd += f' --clip_skip={str(clip_skip)}'
|
|
||||||
|
|
||||||
print(run_cmd)
|
|
||||||
# Run the command
|
|
||||||
subprocess.run(run_cmd)
|
|
||||||
|
|
||||||
# check if output_dir/last is a folder... therefore it is a diffuser model
|
|
||||||
last_dir = pathlib.Path(f'{output_dir}/last')
|
|
||||||
|
|
||||||
if not last_dir.is_dir():
|
|
||||||
# Copy inference model for v2 if required
|
|
||||||
save_inference_file(output_dir, v2, v_parameterization)
|
|
||||||
|
|
||||||
|
|
||||||
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
|
||||||
# define a list of substrings to search for
|
|
||||||
substrings_v2 = [
|
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
|
||||||
'stabilityai/stable-diffusion-2-base',
|
|
||||||
]
|
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
|
|
||||||
if str(value) in substrings_v2:
|
|
||||||
print('SD v2 model detected. Setting --v2 parameter')
|
|
||||||
v2 = True
|
|
||||||
v_parameterization = False
|
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
|
||||||
|
|
||||||
# define a list of substrings to search for v-objective
|
|
||||||
substrings_v_parameterization = [
|
|
||||||
'stabilityai/stable-diffusion-2-1',
|
|
||||||
'stabilityai/stable-diffusion-2',
|
|
||||||
]
|
|
||||||
|
|
||||||
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
|
|
||||||
if str(value) in substrings_v_parameterization:
|
|
||||||
print(
|
|
||||||
'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
|
|
||||||
)
|
|
||||||
v2 = True
|
|
||||||
v_parameterization = True
|
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
|
||||||
|
|
||||||
# define a list of substrings to v1.x
|
|
||||||
substrings_v1_model = [
|
|
||||||
'CompVis/stable-diffusion-v1-4',
|
|
||||||
'runwayml/stable-diffusion-v1-5',
|
|
||||||
]
|
|
||||||
|
|
||||||
if str(value) in substrings_v1_model:
|
|
||||||
v2 = False
|
|
||||||
v_parameterization = False
|
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
|
||||||
|
|
||||||
if value == 'custom':
|
|
||||||
value = ''
|
|
||||||
v2 = False
|
|
||||||
v_parameterization = False
|
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
|
||||||
|
|
||||||
|
|
||||||
def remove_doublequote(file_path):
|
|
||||||
if file_path != None:
|
|
||||||
file_path = file_path.replace('"', '')
|
|
||||||
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
|
|
||||||
def UI(username, password):
|
|
||||||
|
|
||||||
css = ''
|
|
||||||
|
|
||||||
if os.path.exists('./style.css'):
|
|
||||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
|
||||||
print('Load CSS...')
|
|
||||||
css += file.read() + '\n'
|
|
||||||
|
|
||||||
interface = gr.Blocks(css=css)
|
|
||||||
|
|
||||||
with interface:
|
|
||||||
with gr.Tab('Finetune'):
|
|
||||||
finetune_tab()
|
|
||||||
with gr.Tab('Utilities'):
|
|
||||||
utilities_tab(enable_dreambooth_tab=False)
|
|
||||||
|
|
||||||
# Show the interface
|
|
||||||
if not username == '':
|
|
||||||
interface.launch(auth=(username, password))
|
|
||||||
else:
|
|
||||||
interface.launch()
|
|
||||||
|
|
||||||
|
|
||||||
def finetune_tab():
|
|
||||||
dummy_ft_true = gr.Label(value=True, visible=False)
|
|
||||||
dummy_ft_false = gr.Label(value=False, visible=False)
|
|
||||||
gr.Markdown('Train a custom model using kohya finetune python code...')
|
|
||||||
with gr.Accordion('Configuration file', open=False):
|
|
||||||
with gr.Row():
|
|
||||||
button_open_config = gr.Button(
|
|
||||||
f'Open {folder_symbol}', elem_id='open_folder'
|
|
||||||
)
|
|
||||||
button_save_config = gr.Button(
|
|
||||||
f'Save {save_style_symbol}', elem_id='open_folder'
|
|
||||||
)
|
|
||||||
button_save_as_config = gr.Button(
|
|
||||||
f'Save as... {save_style_symbol}',
|
|
||||||
elem_id='open_folder',
|
|
||||||
)
|
|
||||||
config_file_name = gr.Textbox(
|
|
||||||
label='', placeholder='type file path or use buttons...'
|
|
||||||
)
|
|
||||||
config_file_name.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[config_file_name],
|
|
||||||
outputs=[config_file_name],
|
|
||||||
)
|
|
||||||
with gr.Tab('Source model'):
|
|
||||||
# Define the input elements
|
|
||||||
with gr.Row():
|
|
||||||
pretrained_model_name_or_path_input = gr.Textbox(
|
|
||||||
label='Pretrained model name or path',
|
|
||||||
placeholder='enter the path to custom model or name of pretrained model',
|
|
||||||
)
|
|
||||||
pretrained_model_name_or_path_file = gr.Button(
|
|
||||||
document_symbol, elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
pretrained_model_name_or_path_file.click(
|
|
||||||
get_any_file_path,
|
|
||||||
inputs=pretrained_model_name_or_path_input,
|
|
||||||
outputs=pretrained_model_name_or_path_input,
|
|
||||||
)
|
|
||||||
pretrained_model_name_or_path_folder = gr.Button(
|
|
||||||
folder_symbol, elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
pretrained_model_name_or_path_folder.click(
|
|
||||||
get_folder_path,
|
|
||||||
inputs=pretrained_model_name_or_path_input,
|
|
||||||
outputs=pretrained_model_name_or_path_input,
|
|
||||||
)
|
|
||||||
model_list = gr.Dropdown(
|
|
||||||
label='(Optional) Model Quick Pick',
|
|
||||||
choices=[
|
|
||||||
'custom',
|
|
||||||
'stabilityai/stable-diffusion-2-1-base',
|
|
||||||
'stabilityai/stable-diffusion-2-base',
|
|
||||||
'stabilityai/stable-diffusion-2-1',
|
|
||||||
'stabilityai/stable-diffusion-2',
|
|
||||||
'runwayml/stable-diffusion-v1-5',
|
|
||||||
'CompVis/stable-diffusion-v1-4',
|
|
||||||
],
|
|
||||||
)
|
|
||||||
save_model_as_dropdown = gr.Dropdown(
|
|
||||||
label='Save trained model as',
|
|
||||||
choices=[
|
|
||||||
'same as source model',
|
|
||||||
'ckpt',
|
|
||||||
'diffusers',
|
|
||||||
'diffusers_safetensors',
|
|
||||||
'safetensors',
|
|
||||||
],
|
|
||||||
value='same as source model',
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
v2_input = gr.Checkbox(label='v2', value=True)
|
|
||||||
v_parameterization_input = gr.Checkbox(
|
|
||||||
label='v_parameterization', value=False
|
|
||||||
)
|
|
||||||
model_list.change(
|
|
||||||
set_pretrained_model_name_or_path_input,
|
|
||||||
inputs=[model_list, v2_input, v_parameterization_input],
|
|
||||||
outputs=[
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
with gr.Tab('Folders'):
|
|
||||||
with gr.Row():
|
|
||||||
train_dir_input = gr.Textbox(
|
|
||||||
label='Training config folder',
|
|
||||||
placeholder='folder where the training configuration files will be saved',
|
|
||||||
)
|
|
||||||
train_dir_folder = gr.Button(
|
|
||||||
folder_symbol, elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
train_dir_folder.click(get_folder_path, outputs=train_dir_input)
|
|
||||||
|
|
||||||
image_folder_input = gr.Textbox(
|
|
||||||
label='Training Image folder',
|
|
||||||
placeholder='folder where the training images are located',
|
|
||||||
)
|
|
||||||
image_folder_input_folder = gr.Button(
|
|
||||||
folder_symbol, elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
image_folder_input_folder.click(
|
|
||||||
get_folder_path, outputs=image_folder_input
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
output_dir_input = gr.Textbox(
|
|
||||||
label='Output folder',
|
|
||||||
placeholder='folder where the model will be saved',
|
|
||||||
)
|
|
||||||
output_dir_input_folder = gr.Button(
|
|
||||||
folder_symbol, elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
output_dir_input_folder.click(
|
|
||||||
get_folder_path, outputs=output_dir_input
|
|
||||||
)
|
|
||||||
|
|
||||||
logging_dir_input = gr.Textbox(
|
|
||||||
label='Logging folder',
|
|
||||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
|
||||||
)
|
|
||||||
logging_dir_input_folder = gr.Button(
|
|
||||||
folder_symbol, elem_id='open_folder_small'
|
|
||||||
)
|
|
||||||
logging_dir_input_folder.click(
|
|
||||||
get_folder_path, outputs=logging_dir_input
|
|
||||||
)
|
|
||||||
train_dir_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[train_dir_input],
|
|
||||||
outputs=[train_dir_input],
|
|
||||||
)
|
|
||||||
image_folder_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[image_folder_input],
|
|
||||||
outputs=[image_folder_input],
|
|
||||||
)
|
|
||||||
output_dir_input.change(
|
|
||||||
remove_doublequote,
|
|
||||||
inputs=[output_dir_input],
|
|
||||||
outputs=[output_dir_input],
|
|
||||||
)
|
|
||||||
with gr.Tab('Dataset preparation'):
|
|
||||||
with gr.Row():
|
|
||||||
max_resolution_input = gr.Textbox(
|
|
||||||
label='Resolution (width,height)', value='512,512'
|
|
||||||
)
|
|
||||||
min_bucket_reso = gr.Textbox(
|
|
||||||
label='Min bucket resolution', value='256'
|
|
||||||
)
|
|
||||||
max_bucket_reso = gr.Textbox(
|
|
||||||
label='Max bucket resolution', value='1024'
|
|
||||||
)
|
|
||||||
batch_size = gr.Textbox(label='Batch size', value='1')
|
|
||||||
with gr.Accordion('Advanced parameters', open=False):
|
|
||||||
with gr.Row():
|
|
||||||
caption_metadata_filename = gr.Textbox(
|
|
||||||
label='Caption metadata filename', value='meta_cap.json'
|
|
||||||
)
|
|
||||||
latent_metadata_filename = gr.Textbox(
|
|
||||||
label='Latent metadata filename', value='meta_lat.json'
|
|
||||||
)
|
|
||||||
full_path = gr.Checkbox(label='Use full path', value=True)
|
|
||||||
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
|
|
||||||
with gr.Tab('Training parameters'):
|
|
||||||
with gr.Row():
|
|
||||||
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
|
|
||||||
lr_scheduler_input = gr.Dropdown(
|
|
||||||
label='LR Scheduler',
|
|
||||||
choices=[
|
|
||||||
'constant',
|
|
||||||
'constant_with_warmup',
|
|
||||||
'cosine',
|
|
||||||
'cosine_with_restarts',
|
|
||||||
'linear',
|
|
||||||
'polynomial',
|
|
||||||
],
|
|
||||||
value='constant',
|
|
||||||
)
|
|
||||||
lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
|
|
||||||
with gr.Row():
|
|
||||||
dataset_repeats_input = gr.Textbox(
|
|
||||||
label='Dataset repeats', value=40
|
|
||||||
)
|
|
||||||
train_batch_size_input = gr.Slider(
|
|
||||||
minimum=1,
|
|
||||||
maximum=32,
|
|
||||||
label='Train batch size',
|
|
||||||
value=1,
|
|
||||||
step=1,
|
|
||||||
)
|
|
||||||
epoch_input = gr.Textbox(label='Epoch', value=1)
|
|
||||||
save_every_n_epochs_input = gr.Textbox(
|
|
||||||
label='Save every N epochs', value=1
|
|
||||||
)
|
|
||||||
with gr.Row():
|
|
||||||
mixed_precision_input = gr.Dropdown(
|
|
||||||
label='Mixed precision',
|
|
||||||
choices=[
|
|
||||||
'no',
|
|
||||||
'fp16',
|
|
||||||
'bf16',
|
|
||||||
],
|
|
||||||
value='fp16',
|
|
||||||
)
|
|
||||||
save_precision_input = gr.Dropdown(
|
|
||||||
label='Save precision',
|
|
||||||
choices=[
|
|
||||||
'float',
|
|
||||||
'fp16',
|
|
||||||
'bf16',
|
|
||||||
],
|
|
||||||
value='fp16',
|
|
||||||
)
|
|
||||||
num_cpu_threads_per_process_input = gr.Slider(
|
|
||||||
minimum=1,
|
|
||||||
maximum=os.cpu_count(),
|
|
||||||
step=1,
|
|
||||||
label='Number of CPU threads per process',
|
|
||||||
value=os.cpu_count(),
|
|
||||||
)
|
|
||||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
|
||||||
with gr.Row():
|
|
||||||
caption_extention_input = gr.Textbox(
|
|
||||||
label='Caption Extension',
|
|
||||||
placeholder='(Optional) Extension for caption files. default: .txt',
|
|
||||||
)
|
|
||||||
train_text_encoder_input = gr.Checkbox(
|
|
||||||
label='Train text encoder', value=True
|
|
||||||
)
|
|
||||||
with gr.Accordion('Advanced parameters', open=False):
|
|
||||||
with gr.Row():
|
|
||||||
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
|
|
||||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
|
||||||
clip_skip = gr.Slider(
|
|
||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
|
||||||
)
|
|
||||||
with gr.Box():
|
|
||||||
with gr.Row():
|
|
||||||
create_caption = gr.Checkbox(
|
|
||||||
label='Generate caption metadata', value=True
|
|
||||||
)
|
|
||||||
create_buckets = gr.Checkbox(
|
|
||||||
label='Generate image buckets metadata', value=True
|
|
||||||
)
|
|
||||||
|
|
||||||
button_run = gr.Button('Train model')
|
|
||||||
|
|
||||||
settings_list = [
|
|
||||||
pretrained_model_name_or_path_input,
|
|
||||||
v2_input,
|
|
||||||
v_parameterization_input,
|
|
||||||
train_dir_input,
|
|
||||||
image_folder_input,
|
|
||||||
output_dir_input,
|
|
||||||
logging_dir_input,
|
|
||||||
max_resolution_input,
|
|
||||||
min_bucket_reso,
|
|
||||||
max_bucket_reso,
|
|
||||||
batch_size,
|
|
||||||
flip_aug,
|
|
||||||
caption_metadata_filename,
|
|
||||||
latent_metadata_filename,
|
|
||||||
full_path,
|
|
||||||
learning_rate_input,
|
|
||||||
lr_scheduler_input,
|
|
||||||
lr_warmup_input,
|
|
||||||
dataset_repeats_input,
|
|
||||||
train_batch_size_input,
|
|
||||||
epoch_input,
|
|
||||||
save_every_n_epochs_input,
|
|
||||||
mixed_precision_input,
|
|
||||||
save_precision_input,
|
|
||||||
seed_input,
|
|
||||||
num_cpu_threads_per_process_input,
|
|
||||||
train_text_encoder_input,
|
|
||||||
create_caption,
|
|
||||||
create_buckets,
|
|
||||||
save_model_as_dropdown,
|
|
||||||
caption_extention_input,
|
|
||||||
use_8bit_adam,
|
|
||||||
xformers,
|
|
||||||
clip_skip,
|
|
||||||
]
|
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
|
||||||
|
|
||||||
button_open_config.click(
|
|
||||||
open_config_file,
|
|
||||||
inputs=[config_file_name] + settings_list,
|
|
||||||
outputs=[config_file_name] + settings_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
button_save_config.click(
|
|
||||||
save_configuration,
|
|
||||||
inputs=[dummy_ft_false, config_file_name] + settings_list,
|
|
||||||
outputs=[config_file_name],
|
|
||||||
)
|
|
||||||
|
|
||||||
button_save_as_config.click(
|
|
||||||
save_configuration,
|
|
||||||
inputs=[dummy_ft_true, config_file_name] + settings_list,
|
|
||||||
outputs=[config_file_name],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
'--username', type=str, default='', help='Username for authentication'
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--password', type=str, default='', help='Password for authentication'
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
UI(username=args.username, password=args.password)
|
|
@ -13,6 +13,7 @@ from library.common_gui import (
|
|||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
save_inference_file,
|
save_inference_file,
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input,
|
||||||
|
gradio_advanced_training,run_cmd_advanced_training
|
||||||
)
|
)
|
||||||
from library.utilities import utilities_tab
|
from library.utilities import utilities_tab
|
||||||
|
|
||||||
@ -67,6 +68,8 @@ def save_configuration(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -150,6 +153,8 @@ def open_config_file(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -218,6 +223,8 @@ def train_model(
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -336,6 +343,7 @@ def train_model(
|
|||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
if (int(max_token_length) > 75):
|
if (int(max_token_length) > 75):
|
||||||
run_cmd += f' --max_token_length={max_token_length}'
|
run_cmd += f' --max_token_length={max_token_length}'
|
||||||
|
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -656,6 +664,7 @@ def finetune_tab():
|
|||||||
],
|
],
|
||||||
value='75',
|
value='75',
|
||||||
)
|
)
|
||||||
|
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
create_caption = gr.Checkbox(
|
create_caption = gr.Checkbox(
|
||||||
@ -710,6 +719,8 @@ def finetune_tab():
|
|||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
output_name,
|
output_name,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
@ -1,38 +1,3 @@
|
|||||||
# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...
|
|
||||||
|
|
||||||
# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
|
|
||||||
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working
|
|
||||||
# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode
|
|
||||||
# v5: fix clip_sample=True for scheduler, add VGG guidance
|
|
||||||
# v6: refactor to use model util, load VAE without vae folder, support safe tensors
|
|
||||||
# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size,
|
|
||||||
# Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction)
|
|
||||||
# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes,
|
|
||||||
# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1
|
|
||||||
# v10: fix app crashes when different image size in prompts
|
|
||||||
|
|
||||||
# Copyright 2022 kohya_ss @kohya_ss
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# license of included scripts:
|
|
||||||
|
|
||||||
# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
|
||||||
# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
|
||||||
|
|
||||||
# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.):
|
|
||||||
# ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
VGG(
|
VGG(
|
||||||
(features): Sequential(
|
(features): Sequential(
|
||||||
@ -81,11 +46,13 @@ VGG(
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import time
|
import time
|
||||||
|
import zipfile
|
||||||
from diffusers.utils import deprecate
|
from diffusers.utils import deprecate
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
import argparse
|
import argparse
|
||||||
@ -517,7 +484,7 @@ class PipelineLike():
|
|||||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||||
|
|
||||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||||
def enable_xformers_memory_efficient_attention(self):
|
def enable_xformers_memory_efficient_attention(self):
|
||||||
r"""
|
r"""
|
||||||
Enable memory efficient attention as implemented in xformers.
|
Enable memory efficient attention as implemented in xformers.
|
||||||
@ -590,6 +557,7 @@ class PipelineLike():
|
|||||||
width: int = 512,
|
width: int = 512,
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
guidance_scale: float = 7.5,
|
guidance_scale: float = 7.5,
|
||||||
|
negative_scale: float = None,
|
||||||
strength: float = 0.8,
|
strength: float = 0.8,
|
||||||
# num_images_per_prompt: Optional[int] = 1,
|
# num_images_per_prompt: Optional[int] = 1,
|
||||||
eta: float = 0.0,
|
eta: float = 0.0,
|
||||||
@ -708,6 +676,11 @@ class PipelineLike():
|
|||||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
# corresponds to doing no classifier free guidance.
|
# corresponds to doing no classifier free guidance.
|
||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
if not do_classifier_free_guidance and negative_scale is not None:
|
||||||
|
print(f"negative_scale is ignored if guidance scalle <= 1.0")
|
||||||
|
negative_scale = None
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
# get unconditional embeddings for classifier free guidance
|
||||||
if negative_prompt is None:
|
if negative_prompt is None:
|
||||||
negative_prompt = [""] * batch_size
|
negative_prompt = [""] * batch_size
|
||||||
@ -729,8 +702,21 @@ class PipelineLike():
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if negative_scale is not None:
|
||||||
|
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
|
||||||
|
pipe=self,
|
||||||
|
prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須
|
||||||
|
uncond_prompt=[""]*batch_size,
|
||||||
|
max_embeddings_multiples=max_embeddings_multiples,
|
||||||
|
clip_skip=self.clip_skip,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
if negative_scale is None:
|
||||||
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||||
|
else:
|
||||||
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
|
||||||
|
|
||||||
# CLIP guidanceで使用するembeddingsを取得する
|
# CLIP guidanceで使用するembeddingsを取得する
|
||||||
if self.clip_guidance_scale > 0:
|
if self.clip_guidance_scale > 0:
|
||||||
@ -861,22 +847,28 @@ class PipelineLike():
|
|||||||
if accepts_eta:
|
if accepts_eta:
|
||||||
extra_step_kwargs["eta"] = eta
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = latents.repeat((2, 1, 1, 1)) if do_classifier_free_guidance else latents
|
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
if negative_scale is None:
|
||||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
else:
|
||||||
|
noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * \
|
||||||
|
(noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond)
|
||||||
|
|
||||||
# perform clip guidance
|
# perform clip guidance
|
||||||
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
|
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
|
||||||
text_embeddings_for_guidance = (text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings)
|
text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[
|
||||||
|
1] if do_classifier_free_guidance else text_embeddings)
|
||||||
|
|
||||||
if self.clip_guidance_scale > 0:
|
if self.clip_guidance_scale > 0:
|
||||||
noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
|
noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
|
||||||
@ -1982,26 +1974,50 @@ def main(args):
|
|||||||
vgg16_model.to(dtype).to(device)
|
vgg16_model.to(dtype).to(device)
|
||||||
|
|
||||||
# networkを組み込む
|
# networkを組み込む
|
||||||
if args.network_module is not None:
|
if args.network_module:
|
||||||
# assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません"
|
networks = []
|
||||||
|
for i, network_module in enumerate(args.network_module):
|
||||||
|
print("import network module:", network_module)
|
||||||
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
print("import network module:", args.network_module)
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
network_module = importlib.import_module(args.network_module)
|
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
|
||||||
|
|
||||||
network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs)
|
net_kwargs = {}
|
||||||
if network is None:
|
if args.network_args and i < len(args.network_args):
|
||||||
return
|
network_args = args.network_args[i]
|
||||||
|
# TODO escape special chars
|
||||||
|
network_args = network_args.split(";")
|
||||||
|
for net_arg in network_args:
|
||||||
|
key, value = net_arg.split("=")
|
||||||
|
net_kwargs[key] = value
|
||||||
|
|
||||||
print("load network weights from:", args.network_weights)
|
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
|
||||||
network.load_weights(args.network_weights)
|
if network is None:
|
||||||
|
return
|
||||||
|
|
||||||
network.apply_to(text_encoder, unet)
|
if args.network_weights and i < len(args.network_weights):
|
||||||
|
network_weight = args.network_weights[i]
|
||||||
|
print("load network weights from:", network_weight)
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if os.path.splitext(network_weight)[1] == '.safetensors':
|
||||||
network.to(memory_format=torch.channels_last)
|
from safetensors.torch import safe_open
|
||||||
network.to(dtype).to(device)
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
if metadata is not None:
|
||||||
|
print(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
|
network.load_weights(network_weight)
|
||||||
|
|
||||||
|
network.apply_to(text_encoder, unet)
|
||||||
|
|
||||||
|
if args.opt_channels_last:
|
||||||
|
network.to(memory_format=torch.channels_last)
|
||||||
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
|
networks.append(network)
|
||||||
else:
|
else:
|
||||||
network = None
|
networks = []
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
print(f"set optimizing: channels last")
|
print(f"set optimizing: channels last")
|
||||||
@ -2010,8 +2026,9 @@ def main(args):
|
|||||||
unet.to(memory_format=torch.channels_last)
|
unet.to(memory_format=torch.channels_last)
|
||||||
if clip_model is not None:
|
if clip_model is not None:
|
||||||
clip_model.to(memory_format=torch.channels_last)
|
clip_model.to(memory_format=torch.channels_last)
|
||||||
if network is not None:
|
if networks:
|
||||||
network.to(memory_format=torch.channels_last)
|
for network in networks:
|
||||||
|
network.to(memory_format=torch.channels_last)
|
||||||
if vgg16_model is not None:
|
if vgg16_model is not None:
|
||||||
vgg16_model.to(memory_format=torch.channels_last)
|
vgg16_model.to(memory_format=torch.channels_last)
|
||||||
|
|
||||||
@ -2053,7 +2070,7 @@ def main(args):
|
|||||||
print(f"convert image to RGB from {image.mode}: {p}")
|
print(f"convert image to RGB from {image.mode}: {p}")
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
def resize_images(imgs, size):
|
def resize_images(imgs, size):
|
||||||
@ -2154,12 +2171,12 @@ def main(args):
|
|||||||
# 1st stageのバッチを作成して呼び出す
|
# 1st stageのバッチを作成して呼び出す
|
||||||
print("process 1st stage1")
|
print("process 1st stage1")
|
||||||
batch_1st = []
|
batch_1st = []
|
||||||
for params1, (width, height, steps, scale, strength) in batch:
|
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
|
||||||
width_1st = int(width * args.highres_fix_scale + .5)
|
width_1st = int(width * args.highres_fix_scale + .5)
|
||||||
height_1st = int(height * args.highres_fix_scale + .5)
|
height_1st = int(height * args.highres_fix_scale + .5)
|
||||||
width_1st = width_1st - width_1st % 32
|
width_1st = width_1st - width_1st % 32
|
||||||
height_1st = height_1st - height_1st % 32
|
height_1st = height_1st - height_1st % 32
|
||||||
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, strength)))
|
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
|
||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
|
|
||||||
# 2nd stageのバッチを作成して以下処理する
|
# 2nd stageのバッチを作成して以下処理する
|
||||||
@ -2171,7 +2188,8 @@ def main(args):
|
|||||||
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
||||||
batch = batch_2nd
|
batch = batch_2nd
|
||||||
|
|
||||||
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width, height, steps, scale, strength) = batch[0]
|
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
|
||||||
|
height, steps, scale, negative_scale, strength) = batch[0]
|
||||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
@ -2247,7 +2265,7 @@ def main(args):
|
|||||||
guide_images = guide_images[0]
|
guide_images = guide_images[0]
|
||||||
|
|
||||||
# generate
|
# generate
|
||||||
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, strength, latents=start_code,
|
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
||||||
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
||||||
if highres_1st and not args.highres_fix_save_1st:
|
if highres_1st and not args.highres_fix_save_1st:
|
||||||
return images
|
return images
|
||||||
@ -2264,6 +2282,8 @@ def main(args):
|
|||||||
metadata.add_text("scale", str(scale))
|
metadata.add_text("scale", str(scale))
|
||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
metadata.add_text("negative-prompt", negative_prompt)
|
metadata.add_text("negative-prompt", negative_prompt)
|
||||||
|
if negative_scale is not None:
|
||||||
|
metadata.add_text("negative-scale", str(negative_scale))
|
||||||
if clip_prompt is not None:
|
if clip_prompt is not None:
|
||||||
metadata.add_text("clip-prompt", clip_prompt)
|
metadata.add_text("clip-prompt", clip_prompt)
|
||||||
|
|
||||||
@ -2316,6 +2336,7 @@ def main(args):
|
|||||||
width = args.W
|
width = args.W
|
||||||
height = args.H
|
height = args.H
|
||||||
scale = args.scale
|
scale = args.scale
|
||||||
|
negative_scale = args.negative_scale
|
||||||
steps = args.steps
|
steps = args.steps
|
||||||
seeds = None
|
seeds = None
|
||||||
strength = 0.8 if args.strength is None else args.strength
|
strength = 0.8 if args.strength is None else args.strength
|
||||||
@ -2358,6 +2379,15 @@ def main(args):
|
|||||||
print(f"scale: {scale}")
|
print(f"scale: {scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE)
|
||||||
|
if m: # negative scale
|
||||||
|
if m.group(1).lower() == 'none':
|
||||||
|
negative_scale = None
|
||||||
|
else:
|
||||||
|
negative_scale = float(m.group(1))
|
||||||
|
print(f"negative scale: {negative_scale}")
|
||||||
|
continue
|
||||||
|
|
||||||
m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
|
m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
|
||||||
if m: # strength
|
if m: # strength
|
||||||
strength = float(m.group(1))
|
strength = float(m.group(1))
|
||||||
@ -2420,8 +2450,9 @@ def main(args):
|
|||||||
print("Use previous image as guide image.")
|
print("Use previous image as guide image.")
|
||||||
guide_image = prev_image
|
guide_image = prev_image
|
||||||
|
|
||||||
|
# TODO named tupleか何かにする
|
||||||
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||||
(width, height, steps, scale, strength))
|
(width, height, steps, scale, negative_scale, strength))
|
||||||
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
||||||
process_batch(batch_data, highres_fix)
|
process_batch(batch_data, highres_fix)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
@ -2481,19 +2512,24 @@ if __name__ == '__main__':
|
|||||||
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
||||||
parser.add_argument("--seed", type=int, default=None,
|
parser.add_argument("--seed", type=int, default=None,
|
||||||
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
|
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
|
||||||
parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
|
parser.add_argument("--iter_same_seed", action='store_true',
|
||||||
|
help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
|
||||||
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
|
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
|
||||||
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
|
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
|
||||||
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
|
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
|
||||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||||
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
||||||
parser.add_argument("--opt_channels_last", action='store_true',
|
parser.add_argument("--opt_channels_last", action='store_true',
|
||||||
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
|
||||||
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
||||||
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
|
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||||
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||||
parser.add_argument("--network_dim", type=int, default=None,
|
help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||||
|
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||||
|
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
|
||||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||||
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||||
@ -2512,6 +2548,8 @@ if __name__ == '__main__':
|
|||||||
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
|
||||||
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
parser.add_argument("--highres_fix_save_1st", action='store_true',
|
||||||
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
|
||||||
|
parser.add_argument("--negative_scale", type=float, default=None,
|
||||||
|
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -261,4 +261,31 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
|
|||||||
v2 = False
|
v2 = False
|
||||||
v_parameterization = False
|
v_parameterization = False
|
||||||
|
|
||||||
return value, v2, v_parameterization
|
return value, v2, v_parameterization
|
||||||
|
|
||||||
|
###
|
||||||
|
### Gradio common GUI section
|
||||||
|
###
|
||||||
|
|
||||||
|
def gradio_advanced_training():
|
||||||
|
with gr.Row():
|
||||||
|
max_train_epochs = gr.Textbox(
|
||||||
|
label='Max train epoch',
|
||||||
|
placeholder='(Optional) Override number of epoch',
|
||||||
|
)
|
||||||
|
max_data_loader_n_workers = gr.Textbox(
|
||||||
|
label='Max num workers for DataLoader',
|
||||||
|
placeholder='(Optional) Override number of epoch. Default: 8',
|
||||||
|
)
|
||||||
|
return max_train_epochs, max_data_loader_n_workers
|
||||||
|
|
||||||
|
def run_cmd_advanced_training(**kwargs):
|
||||||
|
run_cmd = ''
|
||||||
|
max_train_epochs = kwargs.get('max_train_epochs', '')
|
||||||
|
max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '')
|
||||||
|
if not max_train_epochs == '':
|
||||||
|
run_cmd += f' --max_train_epochs="{max_train_epochs}"'
|
||||||
|
if not max_data_loader_n_workers == '':
|
||||||
|
run_cmd += f' --max_data_loader_n_workers="{max_data_loader_n_workers}"'
|
||||||
|
|
||||||
|
return run_cmd
|
@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
|||||||
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
||||||
else:
|
else:
|
||||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||||
|
|
||||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||||
return new_sd
|
return new_sd
|
||||||
|
|
||||||
@ -886,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
|||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
info = vae.load_state_dict(converted_vae_checkpoint)
|
info = vae.load_state_dict(converted_vae_checkpoint)
|
||||||
print("loadint vae:", info)
|
print("loading vae:", info)
|
||||||
|
|
||||||
# convert text_model
|
# convert text_model
|
||||||
if v2:
|
if v2:
|
||||||
@ -1105,12 +1105,12 @@ def load_vae(vae_id, dtype):
|
|||||||
|
|
||||||
if vae_id.endswith(".bin"):
|
if vae_id.endswith(".bin"):
|
||||||
# SD 1.5 VAE on Huggingface
|
# SD 1.5 VAE on Huggingface
|
||||||
vae_sd = torch.load(vae_id, map_location="cpu")
|
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
||||||
converted_vae_checkpoint = vae_sd
|
|
||||||
else:
|
else:
|
||||||
# StableDiffusion
|
# StableDiffusion
|
||||||
vae_model = torch.load(vae_id, map_location="cpu")
|
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||||
vae_sd = vae_model['state_dict']
|
else torch.load(vae_id, map_location="cpu"))
|
||||||
|
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||||
|
|
||||||
# vae only or full model
|
# vae only or full model
|
||||||
full_model = False
|
full_model = False
|
||||||
@ -1132,7 +1132,6 @@ def load_vae(vae_id, dtype):
|
|||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
return vae
|
return vae
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
@ -715,7 +715,10 @@ def debug_dataset(train_dataset):
|
|||||||
def glob_images(dir, base):
|
def glob_images(dir, base):
|
||||||
img_paths = []
|
img_paths = []
|
||||||
for ext in IMAGE_EXTENSIONS:
|
for ext in IMAGE_EXTENSIONS:
|
||||||
img_paths.extend(glob.glob(os.path.join(dir, base + ext)))
|
if base == '*':
|
||||||
|
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext)))
|
||||||
|
else:
|
||||||
|
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext))))
|
||||||
return img_paths
|
return img_paths
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
@ -744,6 +747,20 @@ def exists(val):
|
|||||||
def default(val, d):
|
def default(val, d):
|
||||||
return val if exists(val) else d
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
|
||||||
|
def model_hash(filename):
|
||||||
|
try:
|
||||||
|
with open(filename, "rb") as file:
|
||||||
|
import hashlib
|
||||||
|
m = hashlib.sha256()
|
||||||
|
|
||||||
|
file.seek(0x100000)
|
||||||
|
m.update(file.read(0x10000))
|
||||||
|
return m.hexdigest()[0:8]
|
||||||
|
except FileNotFoundError:
|
||||||
|
return 'NOFILE'
|
||||||
|
|
||||||
|
|
||||||
# flash attention forwards and backwards
|
# flash attention forwards and backwards
|
||||||
|
|
||||||
# https://arxiv.org/abs/2205.14135
|
# https://arxiv.org/abs/2205.14135
|
||||||
@ -1030,6 +1047,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
|
|
||||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||||
|
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||||
|
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||||
|
14
lora_gui.py
14
lora_gui.py
@ -19,7 +19,7 @@ from library.common_gui import (
|
|||||||
get_saveasfile_path,
|
get_saveasfile_path,
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
save_inference_file,
|
save_inference_file,
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input, gradio_advanced_training,run_cmd_advanced_training,
|
||||||
)
|
)
|
||||||
from library.dreambooth_folder_creation_gui import (
|
from library.dreambooth_folder_creation_gui import (
|
||||||
gradio_dreambooth_folder_creation_tab,
|
gradio_dreambooth_folder_creation_tab,
|
||||||
@ -81,6 +81,8 @@ def save_configuration(
|
|||||||
output_name,
|
output_name,
|
||||||
model_list,
|
model_list,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -165,6 +167,8 @@ def open_configuration(
|
|||||||
output_name,
|
output_name,
|
||||||
model_list,
|
model_list,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -233,6 +237,8 @@ def train_model(
|
|||||||
output_name,
|
output_name,
|
||||||
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
model_list, # Keep this. Yes, it is unused here but required given the common list used
|
||||||
max_token_length,
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -410,6 +416,7 @@ def train_model(
|
|||||||
run_cmd += f' --output_name="{output_name}"'
|
run_cmd += f' --output_name="{output_name}"'
|
||||||
if (int(max_token_length) > 75):
|
if (int(max_token_length) > 75):
|
||||||
run_cmd += f' --max_token_length={max_token_length}'
|
run_cmd += f' --max_token_length={max_token_length}'
|
||||||
|
run_cmd += run_cmd_advanced_training(max_train_epochs=max_train_epochs, max_data_loader_n_workers=max_data_loader_n_workers)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
# Run the command
|
# Run the command
|
||||||
@ -795,6 +802,7 @@ def lora_tab(
|
|||||||
],
|
],
|
||||||
value='75',
|
value='75',
|
||||||
)
|
)
|
||||||
|
max_train_epochs, max_data_loader_n_workers = gradio_advanced_training()
|
||||||
|
|
||||||
with gr.Tab('Tools'):
|
with gr.Tab('Tools'):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
@ -854,7 +862,9 @@ def lora_tab(
|
|||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
model_list,
|
model_list,
|
||||||
max_token_length
|
max_token_length,
|
||||||
|
max_train_epochs,
|
||||||
|
max_data_loader_n_workers,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -135,7 +135,7 @@ def svd(args):
|
|||||||
if dir_name and not os.path.exists(dir_name):
|
if dir_name and not os.path.exists(dir_name):
|
||||||
os.makedirs(dir_name, exist_ok=True)
|
os.makedirs(dir_name, exist_ok=True)
|
||||||
|
|
||||||
lora_network_o.save_weights(args.save_to, save_dtype)
|
lora_network_o.save_weights(args.save_to, save_dtype, {})
|
||||||
print(f"LoRA weights are saved to: {args.save_to}")
|
print(f"LoRA weights are saved to: {args.save_to}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file, safe_open
|
||||||
self.weights_sd = load_file(file)
|
self.weights_sd = load_file(file)
|
||||||
else:
|
else:
|
||||||
self.weights_sd = torch.load(file, map_location='cpu')
|
self.weights_sd = torch.load(file, map_location='cpu')
|
||||||
@ -174,7 +174,10 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
def get_trainable_params(self):
|
def get_trainable_params(self):
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def save_weights(self, file, dtype):
|
def save_weights(self, file, dtype, metadata):
|
||||||
|
if metadata is not None and len(metadata) == 0:
|
||||||
|
metadata = None
|
||||||
|
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
@ -185,6 +188,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
save_file(state_dict, file)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
@ -90,4 +90,4 @@ if __name__ == '__main__':
|
|||||||
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert(args)
|
convert(args)
|
||||||
|
@ -236,4 +236,4 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
|
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
process(args)
|
process(args)
|
||||||
|
@ -134,10 +134,15 @@ def train(args):
|
|||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||||
|
|
||||||
|
# 学習ステップ数を計算する
|
||||||
|
if args.max_train_epochs is not None:
|
||||||
|
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||||
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
||||||
|
@ -126,10 +126,15 @@ def train(args):
|
|||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||||
|
|
||||||
|
# 学習ステップ数を計算する
|
||||||
|
if args.max_train_epochs is not None:
|
||||||
|
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||||
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||||
@ -194,9 +199,62 @@ def train(args):
|
|||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"ss_learning_rate": args.learning_rate,
|
||||||
|
"ss_text_encoder_lr": args.text_encoder_lr,
|
||||||
|
"ss_unet_lr": args.unet_lr,
|
||||||
|
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
|
||||||
|
"ss_num_reg_images": train_dataset.num_reg_images,
|
||||||
|
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||||
|
"ss_num_epochs": num_train_epochs,
|
||||||
|
"ss_batch_size_per_device": args.train_batch_size,
|
||||||
|
"ss_total_batch_size": total_batch_size,
|
||||||
|
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||||
|
"ss_max_train_steps": args.max_train_steps,
|
||||||
|
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
||||||
|
"ss_lr_scheduler": args.lr_scheduler,
|
||||||
|
"ss_network_module": args.network_module,
|
||||||
|
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
||||||
|
"ss_mixed_precision": args.mixed_precision,
|
||||||
|
"ss_full_fp16": bool(args.full_fp16),
|
||||||
|
"ss_v2": bool(args.v2),
|
||||||
|
"ss_resolution": args.resolution,
|
||||||
|
"ss_clip_skip": args.clip_skip,
|
||||||
|
"ss_max_token_length": args.max_token_length,
|
||||||
|
"ss_color_aug": bool(args.color_aug),
|
||||||
|
"ss_flip_aug": bool(args.flip_aug),
|
||||||
|
"ss_random_crop": bool(args.random_crop),
|
||||||
|
"ss_shuffle_caption": bool(args.shuffle_caption),
|
||||||
|
"ss_cache_latents": bool(args.cache_latents),
|
||||||
|
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
|
||||||
|
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
|
||||||
|
"ss_max_bucket_reso": args.max_bucket_reso,
|
||||||
|
"ss_seed": args.seed
|
||||||
|
}
|
||||||
|
|
||||||
|
# uncomment if another network is added
|
||||||
|
# for key, value in net_kwargs.items():
|
||||||
|
# metadata["ss_arg_" + key] = value
|
||||||
|
|
||||||
|
if args.pretrained_model_name_or_path is not None:
|
||||||
|
sd_model_name = args.pretrained_model_name_or_path
|
||||||
|
if os.path.exists(sd_model_name):
|
||||||
|
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
||||||
|
sd_model_name = os.path.basename(sd_model_name)
|
||||||
|
metadata["ss_sd_model_name"] = sd_model_name
|
||||||
|
|
||||||
|
if args.vae is not None:
|
||||||
|
vae_name = args.vae
|
||||||
|
if os.path.exists(vae_name):
|
||||||
|
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
||||||
|
vae_name = os.path.basename(vae_name)
|
||||||
|
metadata["ss_vae_name"] = vae_name
|
||||||
|
|
||||||
|
metadata = {k: str(v) for k, v in metadata.items()}
|
||||||
|
|
||||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
@ -208,6 +266,7 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
metadata["ss_epoch"] = str(epoch+1)
|
||||||
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
@ -296,7 +355,7 @@ def train(args):
|
|||||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
unwrap_model(network).save_weights(ckpt_file, save_dtype)
|
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||||
@ -311,6 +370,8 @@ def train(args):
|
|||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
|
metadata["ss_epoch"] = str(num_train_epochs)
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
network = unwrap_model(network)
|
network = unwrap_model(network)
|
||||||
@ -330,7 +391,7 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
network.save_weights(ckpt_file, save_dtype)
|
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
@ -341,6 +402,7 @@ if __name__ == '__main__':
|
|||||||
train_util.add_dataset_arguments(parser, True, True)
|
train_util.add_dataset_arguments(parser, True, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
|
||||||
|
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user