Update to v17
New GUI
This commit is contained in:
parent
01eb9486d3
commit
5f1a465a45
59
README.md
59
README.md
@ -92,7 +92,15 @@ my_asd_dog_dreambooth
|
|||||||
`- dog8.png
|
`- dog8.png
|
||||||
```
|
```
|
||||||
|
|
||||||
## Execution
|
## GUI
|
||||||
|
|
||||||
|
GUI-based training is now supported through Gradio. You can start the GUI by running:
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
python .\dreambooth_gui.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Manual Script Execution
|
||||||
|
|
||||||
### SD1.5 example
|
### SD1.5 example
|
||||||
|
|
||||||
@ -276,22 +284,21 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
|
|||||||
## Options list
|
## Options list
|
||||||
|
|
||||||
```txt
|
```txt
|
||||||
usage: train_db_fixed.py [-h] [--v2] [--v_parameterization]
|
usage: train_db_fixed.py [-h] [--v2] [--v_parameterization] [--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH]
|
||||||
[--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH] [--fine_tuning]
|
[--fine_tuning] [--shuffle_caption] [--caption_extention CAPTION_EXTENTION]
|
||||||
[--shuffle_caption] [--caption_extention CAPTION_EXTENTION]
|
|
||||||
[--caption_extension CAPTION_EXTENSION] [--train_data_dir TRAIN_DATA_DIR]
|
[--caption_extension CAPTION_EXTENSION] [--train_data_dir TRAIN_DATA_DIR]
|
||||||
[--reg_data_dir REG_DATA_DIR] [--dataset_repeats DATASET_REPEATS] [--output_dir OUTPUT_DIR]
|
[--reg_data_dir REG_DATA_DIR] [--dataset_repeats DATASET_REPEATS] [--output_dir OUTPUT_DIR]
|
||||||
[--use_safetensors] [--save_every_n_epochs SAVE_EVERY_N_EPOCHS] [--save_state] [--resume RESUME]
|
[--use_safetensors] [--save_every_n_epochs SAVE_EVERY_N_EPOCHS] [--save_state] [--resume RESUME]
|
||||||
[--prior_loss_weight PRIOR_LOSS_WEIGHT] [--no_token_padding]
|
[--prior_loss_weight PRIOR_LOSS_WEIGHT] [--no_token_padding]
|
||||||
[--stop_text_encoder_training STOP_TEXT_ENCODER_TRAINING] [--color_aug] [--flip_aug]
|
[--stop_text_encoder_training STOP_TEXT_ENCODER_TRAINING] [--color_aug] [--flip_aug]
|
||||||
[--face_crop_aug_range FACE_CROP_AUG_RANGE] [--random_crop] [--debug_dataset]
|
[--face_crop_aug_range FACE_CROP_AUG_RANGE] [--random_crop] [--debug_dataset]
|
||||||
[--resolution RESOLUTION] [--train_batch_size TRAIN_BATCH_SIZE] [--use_8bit_adam]
|
[--resolution RESOLUTION] [--train_batch_size TRAIN_BATCH_SIZE] [--use_8bit_adam] [--mem_eff_attn]
|
||||||
[--mem_eff_attn] [--xformers] [--vae VAE] [--cache_latents] [--enable_bucket]
|
[--xformers] [--vae VAE] [--cache_latents] [--enable_bucket] [--min_bucket_reso MIN_BUCKET_RESO]
|
||||||
[--min_bucket_reso MIN_BUCKET_RESO] [--max_bucket_reso MAX_BUCKET_RESO]
|
[--max_bucket_reso MAX_BUCKET_RESO] [--learning_rate LEARNING_RATE]
|
||||||
[--learning_rate LEARNING_RATE] [--max_train_steps MAX_TRAIN_STEPS] [--seed SEED]
|
[--max_train_steps MAX_TRAIN_STEPS] [--seed SEED] [--gradient_checkpointing]
|
||||||
[--gradient_checkpointing] [--mixed_precision {no,fp16,bf16}]
|
[--mixed_precision {no,fp16,bf16}] [--full_fp16] [--save_precision {None,float,fp16,bf16}]
|
||||||
[--save_precision {None,float,fp16,bf16}] [--clip_skip CLIP_SKIP] [--logging_dir LOGGING_DIR]
|
[--clip_skip CLIP_SKIP] [--logging_dir LOGGING_DIR] [--log_prefix LOG_PREFIX]
|
||||||
[--log_prefix LOG_PREFIX] [--lr_scheduler LR_SCHEDULER] [--lr_warmup_steps LR_WARMUP_STEPS]
|
[--lr_scheduler LR_SCHEDULER] [--lr_warmup_steps LR_WARMUP_STEPS]
|
||||||
|
|
||||||
options:
|
options:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
@ -303,7 +310,7 @@ options:
|
|||||||
--fine_tuning fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする
|
--fine_tuning fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする
|
||||||
--shuffle_caption shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする
|
--shuffle_caption shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする
|
||||||
--caption_extention CAPTION_EXTENTION
|
--caption_extention CAPTION_EXTENTION
|
||||||
extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを 残してあります)
|
extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)
|
||||||
--caption_extension CAPTION_EXTENSION
|
--caption_extension CAPTION_EXTENSION
|
||||||
extension of caption files / 読み込むcaptionファイルの拡張子
|
extension of caption files / 読み込むcaptionファイルの拡張子
|
||||||
--train_data_dir TRAIN_DATA_DIR
|
--train_data_dir TRAIN_DATA_DIR
|
||||||
@ -314,10 +321,9 @@ options:
|
|||||||
repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数
|
repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数
|
||||||
--output_dir OUTPUT_DIR
|
--output_dir OUTPUT_DIR
|
||||||
directory to output trained model / 学習後のモデル出力先ディレクトリ
|
directory to output trained model / 学習後のモデル出力先ディレクトリ
|
||||||
--use_safetensors use safetensors format for StableDiffusion checkpoint /
|
--use_safetensors use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する
|
||||||
StableDiffusionのcheckpointをsafetensors形式で保存する
|
|
||||||
--save_every_n_epochs SAVE_EVERY_N_EPOCHS
|
--save_every_n_epochs SAVE_EVERY_N_EPOCHS
|
||||||
save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します
|
save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する
|
||||||
--save_state save training state additionally (including optimizer states etc.) /
|
--save_state save training state additionally (including optimizer states etc.) /
|
||||||
optimizerなど学習状態も含めたstateを追加で保存する
|
optimizerなど学習状態も含めたstateを追加で保存する
|
||||||
--resume RESUME saved state to resume training / 学習再開するモデルのstate
|
--resume RESUME saved state to resume training / 学習再開するモデルのstate
|
||||||
@ -333,17 +339,17 @@ options:
|
|||||||
enable face-centered crop augmentation and its range (e.g. 2.0,4.0) /
|
enable face-centered crop augmentation and its range (e.g. 2.0,4.0) /
|
||||||
学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)
|
学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)
|
||||||
--random_crop enable random crop (for style training in face-centered crop augmentation) /
|
--random_crop enable random crop (for style training in face-centered crop augmentation) /
|
||||||
ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)
|
ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)
|
||||||
--debug_dataset show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)
|
--debug_dataset show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)
|
||||||
--resolution RESOLUTION
|
--resolution RESOLUTION
|
||||||
resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高 さ'指定)
|
resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)
|
||||||
--train_batch_size TRAIN_BATCH_SIZE
|
--train_batch_size TRAIN_BATCH_SIZE
|
||||||
batch size for training (1 means one train or reg data, not train/reg pair) /
|
batch size for training (1 means one train or reg data, not train/reg pair) /
|
||||||
学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)
|
学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)
|
||||||
--use_8bit_adam use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)
|
--use_8bit_adam use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)
|
||||||
--mem_eff_attn use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う
|
--mem_eff_attn use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う
|
||||||
--xformers use xformers for CrossAttention / CrossAttentionにxformersを使う
|
--xformers use xformers for CrossAttention / CrossAttentionにxformersを使う
|
||||||
--vae VAE path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ
|
--vae VAE path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ
|
||||||
--cache_latents cache latents to reduce memory (augmentations must be disabled) /
|
--cache_latents cache latents to reduce memory (augmentations must be disabled) /
|
||||||
メモリ削減のためにlatentをcacheする(augmentationは使用不可)
|
メモリ削減のためにlatentをcacheする(augmentationは使用不可)
|
||||||
--enable_bucket enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする
|
--enable_bucket enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする
|
||||||
@ -360,19 +366,20 @@ options:
|
|||||||
enable gradient checkpointing / grandient checkpointingを有効にする
|
enable gradient checkpointing / grandient checkpointingを有効にする
|
||||||
--mixed_precision {no,fp16,bf16}
|
--mixed_precision {no,fp16,bf16}
|
||||||
use mixed precision / 混合精度を使う場合、その精度
|
use mixed precision / 混合精度を使う場合、その精度
|
||||||
|
--full_fp16 fp16 training including gradients / 勾配も含めてfp16で学習する
|
||||||
--save_precision {None,float,fp16,bf16}
|
--save_precision {None,float,fp16,bf16}
|
||||||
precision in saving (available in StableDiffusion checkpoint) /
|
precision in saving (available in StableDiffusion checkpoint) /
|
||||||
保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)
|
保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)
|
||||||
--clip_skip CLIP_SKIP
|
--clip_skip CLIP_SKIP
|
||||||
use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を 用いる(nは1以上)
|
use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)
|
||||||
--logging_dir LOGGING_DIR
|
--logging_dir LOGGING_DIR
|
||||||
enable logging and output TensorBoard log to this directory /
|
enable logging and output TensorBoard log to this directory /
|
||||||
ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する
|
ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する
|
||||||
--log_prefix LOG_PREFIX
|
--log_prefix LOG_PREFIX
|
||||||
add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列
|
add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列
|
||||||
--lr_scheduler LR_SCHEDULER
|
--lr_scheduler LR_SCHEDULER
|
||||||
scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts,
|
scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial,
|
||||||
polynomial, constant (default), constant_with_warmup
|
constant (default), constant_with_warmup
|
||||||
--lr_warmup_steps LR_WARMUP_STEPS
|
--lr_warmup_steps LR_WARMUP_STEPS
|
||||||
Number of steps for the warmup in the lr scheduler (default is 0) /
|
Number of steps for the warmup in the lr scheduler (default is 0) /
|
||||||
学習率のスケジューラをウォームアップするステップ数(デフォルト0)
|
学習率のスケジューラをウォームアップするステップ数(デフォルト0)
|
||||||
|
@ -34,7 +34,9 @@ def save_variables(
|
|||||||
cache_latent,
|
cache_latent,
|
||||||
caption_extention,
|
caption_extention,
|
||||||
use_safetensors,
|
use_safetensors,
|
||||||
enable_bucket
|
enable_bucket,
|
||||||
|
gradient_checkpointing,
|
||||||
|
full_fp16
|
||||||
):
|
):
|
||||||
# Return the values of the variables as a dictionary
|
# Return the values of the variables as a dictionary
|
||||||
variables = {
|
variables = {
|
||||||
@ -61,7 +63,9 @@ def save_variables(
|
|||||||
"cache_latent": cache_latent,
|
"cache_latent": cache_latent,
|
||||||
"caption_extention": caption_extention,
|
"caption_extention": caption_extention,
|
||||||
"use_safetensors": use_safetensors,
|
"use_safetensors": use_safetensors,
|
||||||
"enable_bucket": enable_bucket
|
"enable_bucket": enable_bucket,
|
||||||
|
"gradient_checkpointing": gradient_checkpointing,
|
||||||
|
"full_fp16": full_fp16
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save the data to the selected file
|
# Save the data to the selected file
|
||||||
@ -100,6 +104,8 @@ def load_variables(file_path):
|
|||||||
my_data.get("caption_extention", None),
|
my_data.get("caption_extention", None),
|
||||||
my_data.get("use_safetensors", None),
|
my_data.get("use_safetensors", None),
|
||||||
my_data.get("enable_bucket", None),
|
my_data.get("enable_bucket", None),
|
||||||
|
my_data.get("gradient_checkpointing", None),
|
||||||
|
my_data.get("full_fp16", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -127,7 +133,9 @@ def train_model(
|
|||||||
cache_latent,
|
cache_latent,
|
||||||
caption_extention,
|
caption_extention,
|
||||||
use_safetensors,
|
use_safetensors,
|
||||||
enable_bucket
|
enable_bucket,
|
||||||
|
gradient_checkpointing,
|
||||||
|
full_fp16
|
||||||
):
|
):
|
||||||
def save_inference_file(output_dir, v2, v_model):
|
def save_inference_file(output_dir, v2, v_model):
|
||||||
# Copy inference model for v2 if required
|
# Copy inference model for v2 if required
|
||||||
@ -189,6 +197,10 @@ def train_model(
|
|||||||
run_cmd += " --use_safetensors"
|
run_cmd += " --use_safetensors"
|
||||||
if enable_bucket:
|
if enable_bucket:
|
||||||
run_cmd += " --enable_bucket"
|
run_cmd += " --enable_bucket"
|
||||||
|
if gradient_checkpointing:
|
||||||
|
run_cmd += " --gradient_checkpointing"
|
||||||
|
if full_fp16:
|
||||||
|
run_cmd += " --full_fp16"
|
||||||
run_cmd += f" --pretrained_model_name_or_path={pretrained_model_name_or_path}"
|
run_cmd += f" --pretrained_model_name_or_path={pretrained_model_name_or_path}"
|
||||||
run_cmd += f" --train_data_dir={train_data_dir}"
|
run_cmd += f" --train_data_dir={train_data_dir}"
|
||||||
run_cmd += f" --reg_data_dir={reg_data_dir}"
|
run_cmd += f" --reg_data_dir={reg_data_dir}"
|
||||||
@ -287,8 +299,8 @@ with interface:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
config_file_name = gr.Textbox(
|
config_file_name = gr.Textbox(
|
||||||
label="Config file name")
|
label="Config file name")
|
||||||
b1 = gr.Button("Load config")
|
button_load_config = gr.Button("Load config")
|
||||||
b2 = gr.Button("Save config")
|
button_save_config = gr.Button("Save config")
|
||||||
with gr.Tab("Source model"):
|
with gr.Tab("Source model"):
|
||||||
# Define the input elements
|
# Define the input elements
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -399,6 +411,12 @@ with interface:
|
|||||||
cache_latent_input = gr.Checkbox(
|
cache_latent_input = gr.Checkbox(
|
||||||
label="Cache latent", value=True
|
label="Cache latent", value=True
|
||||||
)
|
)
|
||||||
|
gradient_checkpointing_input = gr.Checkbox(
|
||||||
|
label="Gradient checkpointing", value=False
|
||||||
|
)
|
||||||
|
full_fp16_input = gr.Checkbox(
|
||||||
|
label="Full fp16 training (experimental)", value=False
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Tab("Model conversion"):
|
with gr.Tab("Model conversion"):
|
||||||
convert_to_safetensors_input = gr.Checkbox(
|
convert_to_safetensors_input = gr.Checkbox(
|
||||||
@ -408,9 +426,9 @@ with interface:
|
|||||||
label="Convert to CKPT", value=False
|
label="Convert to CKPT", value=False
|
||||||
)
|
)
|
||||||
|
|
||||||
b3 = gr.Button("Run")
|
button_run = gr.Button("Run")
|
||||||
|
|
||||||
b1.click(
|
button_load_config.click(
|
||||||
load_variables,
|
load_variables,
|
||||||
inputs=[config_file_name],
|
inputs=[config_file_name],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -437,11 +455,13 @@ with interface:
|
|||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
use_safetensors_input,
|
||||||
enable_bucket_input
|
enable_bucket_input,
|
||||||
|
gradient_checkpointing_input,
|
||||||
|
full_fp16_input
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
b2.click(
|
button_save_config.click(
|
||||||
save_variables,
|
save_variables,
|
||||||
inputs=[
|
inputs=[
|
||||||
config_file_name,
|
config_file_name,
|
||||||
@ -468,10 +488,12 @@ with interface:
|
|||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
use_safetensors_input,
|
||||||
enable_bucket_input
|
enable_bucket_input,
|
||||||
|
gradient_checkpointing_input,
|
||||||
|
full_fp16_input
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
b3.click(
|
button_run.click(
|
||||||
train_model,
|
train_model,
|
||||||
inputs=[
|
inputs=[
|
||||||
pretrained_model_name_or_path_input,
|
pretrained_model_name_or_path_input,
|
||||||
@ -497,7 +519,9 @@ with interface:
|
|||||||
cache_latent_input,
|
cache_latent_input,
|
||||||
caption_extention_input,
|
caption_extention_input,
|
||||||
use_safetensors_input,
|
use_safetensors_input,
|
||||||
enable_bucket_input
|
enable_bucket_input,
|
||||||
|
gradient_checkpointing_input,
|
||||||
|
full_fp16_input
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -813,7 +813,7 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region 自作のモデル読み書き
|
# region 自作のモデル読み書きなど
|
||||||
|
|
||||||
def is_safetensors(path):
|
def is_safetensors(path):
|
||||||
return os.path.splitext(path)[1].lower() == '.safetensors'
|
return os.path.splitext(path)[1].lower() == '.safetensors'
|
||||||
@ -1046,7 +1046,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
|
|||||||
return key_count
|
return key_count
|
||||||
|
|
||||||
|
|
||||||
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None):
|
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
||||||
if vae is None:
|
if vae is None:
|
||||||
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
||||||
pipeline = StableDiffusionPipeline(
|
pipeline = StableDiffusionPipeline(
|
||||||
@ -1059,7 +1059,7 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
|||||||
feature_extractor=None,
|
feature_extractor=None,
|
||||||
requires_safety_checker=None,
|
requires_safety_checker=None,
|
||||||
)
|
)
|
||||||
pipeline.save_pretrained(output_dir)
|
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
||||||
|
|
||||||
|
|
||||||
VAE_PREFIX = "first_stage_model."
|
VAE_PREFIX = "first_stage_model."
|
||||||
@ -1117,6 +1117,7 @@ def get_epoch_ckpt_name(use_safetensors, epoch):
|
|||||||
def get_last_ckpt_name(use_safetensors):
|
def get_last_ckpt_name(use_safetensors):
|
||||||
return f"last" + (".safetensors" if use_safetensors else ".ckpt")
|
return f"last" + (".safetensors" if use_safetensors else ".ckpt")
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
accelerate==0.14.0
|
accelerate==0.15.0
|
||||||
transformers==4.25.1
|
transformers==4.25.1
|
||||||
ftfy
|
ftfy
|
||||||
albumentations
|
albumentations
|
||||||
opencv-python
|
opencv-python
|
||||||
einops
|
einops
|
||||||
diffusers[torch]==0.9.0
|
diffusers[torch]==0.10.2
|
||||||
pytorch_lightning
|
pytorch_lightning
|
||||||
bitsandbytes==0.35.0
|
bitsandbytes==0.35.0
|
||||||
tensorboard
|
tensorboard
|
||||||
safetensors==0.2.5
|
safetensors==0.2.6
|
||||||
gradio
|
gradio
|
||||||
altair
|
altair
|
@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -23,7 +22,7 @@ del theta_prune
|
|||||||
|
|
||||||
if args.half:
|
if args.half:
|
||||||
print("Halving model...")
|
print("Halving model...")
|
||||||
state_dict = {k: v.half() for k, v in theta.items()}
|
state_dict = {k: v.half() for k, v in tqdm(theta.items(), desc="Halving weights")}
|
||||||
else:
|
else:
|
||||||
state_dict = theta
|
state_dict = theta
|
||||||
|
|
||||||
|
@ -13,6 +13,9 @@
|
|||||||
# v12: stop text encoder training, tqdm smoothing
|
# v12: stop text encoder training, tqdm smoothing
|
||||||
# v13: bug fix
|
# v13: bug fix
|
||||||
# v14: refactor to use model_util, add log prefix, support safetensors, support vae loading, keep vae in CPU to save the loaded vae
|
# v14: refactor to use model_util, add log prefix, support safetensors, support vae loading, keep vae in CPU to save the loaded vae
|
||||||
|
# v15: model_util update
|
||||||
|
# v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0
|
||||||
|
# v17: add fp16 gradient training (experimental)
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
@ -43,7 +46,7 @@ import model_util
|
|||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||||
|
|
||||||
# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336"
|
# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336"
|
||||||
|
|
||||||
@ -596,10 +599,6 @@ def replace_unet_cross_attn_to_memory_efficient():
|
|||||||
|
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
|
|
||||||
# diffusers 0.6.0
|
|
||||||
if type(self.to_out) is torch.nn.Sequential:
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
# diffusers 0.7.0~
|
# diffusers 0.7.0~
|
||||||
out = self.to_out[0](out)
|
out = self.to_out[0](out)
|
||||||
out = self.to_out[1](out)
|
out = self.to_out[1](out)
|
||||||
@ -633,10 +632,6 @@ def replace_unet_cross_attn_to_xformers():
|
|||||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||||
# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
# diffusers 0.6.0
|
|
||||||
if type(self.to_out) is torch.nn.Sequential:
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
# diffusers 0.7.0~
|
# diffusers 0.7.0~
|
||||||
out = self.to_out[0](out)
|
out = self.to_out[0](out)
|
||||||
out = self.to_out[1](out)
|
out = self.to_out[1](out)
|
||||||
@ -821,6 +816,19 @@ def train(args):
|
|||||||
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
|
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
|
||||||
log_with=log_with, logging_dir=logging_dir)
|
log_with=log_with, logging_dir=logging_dir)
|
||||||
|
|
||||||
|
# accelerateの互換性問題を解決する
|
||||||
|
accelerator_0_15 = True
|
||||||
|
try:
|
||||||
|
accelerator.unwrap_model("dummy", True)
|
||||||
|
print("Using accelerator 0.15.0 or above.")
|
||||||
|
except TypeError:
|
||||||
|
accelerator_0_15 = False
|
||||||
|
|
||||||
|
def unwrap_model(model):
|
||||||
|
if accelerator_0_15:
|
||||||
|
return accelerator.unwrap_model(model, True)
|
||||||
|
return accelerator.unwrap_model(model)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype = torch.float32
|
weight_dtype = torch.float32
|
||||||
if args.mixed_precision == "fp16":
|
if args.mixed_precision == "fp16":
|
||||||
@ -914,6 +922,13 @@ def train(args):
|
|||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
|
if args.full_fp16:
|
||||||
|
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
|
print("enable full fp16 training.")
|
||||||
|
unet.to(weight_dtype)
|
||||||
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||||
@ -921,6 +936,15 @@ def train(args):
|
|||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
|
if args.full_fp16:
|
||||||
|
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||||
|
|
||||||
|
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||||
|
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
||||||
|
|
||||||
|
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
if args.resume is not None:
|
if args.resume is not None:
|
||||||
print(f"resume training from state: {args.resume}")
|
print(f"resume training from state: {args.resume}")
|
||||||
@ -946,8 +970,8 @@ def train(args):
|
|||||||
|
|
||||||
# v12で更新:clip_sample=Falseに
|
# v12で更新:clip_sample=Falseに
|
||||||
# Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる
|
# Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる
|
||||||
# 既存の1.4/1.5/2.0はすべてschedulerのconfigは(クラス名を除いて)同じ
|
# 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ
|
||||||
# よくソースを見たら学習時は関係ないや(;'∀')
|
# よくソースを見たら学習時はclip_sampleは関係ないや(;'∀')
|
||||||
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||||
num_train_timesteps=1000, clip_sample=False)
|
num_train_timesteps=1000, clip_sample=False)
|
||||||
|
|
||||||
@ -1006,31 +1030,8 @@ def train(args):
|
|||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
# v-parameterization training
|
# v-parameterization training
|
||||||
# こうしたい:
|
# Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う
|
||||||
# target = noise_scheduler.get_v(latents, noise, timesteps)
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
|
||||||
# StabilityAiのddpm.pyのコード:
|
|
||||||
# elif self.parameterization == "v":
|
|
||||||
# target = self.get_v(x_start, noise, t)
|
|
||||||
# ...
|
|
||||||
# def get_v(self, x, noise, t):
|
|
||||||
# return (
|
|
||||||
# extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
|
|
||||||
# extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
|
||||||
# )
|
|
||||||
|
|
||||||
# scheduling_ddim.pyのコード:
|
|
||||||
# elif self.config.prediction_type == "v_prediction":
|
|
||||||
# pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
|
||||||
# # predict V
|
|
||||||
# model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
|
||||||
|
|
||||||
# これでいいかな?:
|
|
||||||
alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
|
|
||||||
beta_prod_t = 1 - alpha_prod_t
|
|
||||||
alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # broadcastされないらしいのでreshape
|
|
||||||
beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
|
|
||||||
target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
|
|
||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
@ -1081,13 +1082,14 @@ def train(args):
|
|||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1))
|
ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1))
|
||||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet),
|
||||||
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype, vae)
|
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype, vae)
|
||||||
else:
|
else:
|
||||||
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
|
model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder),
|
||||||
accelerator.unwrap_model(unet), args.pretrained_model_name_or_path)
|
unwrap_model(unet), args.pretrained_model_name_or_path,
|
||||||
|
use_safetensors=args.use_safetensors)
|
||||||
|
|
||||||
if args.save_state:
|
if args.save_state:
|
||||||
print("saving state.")
|
print("saving state.")
|
||||||
@ -1095,8 +1097,8 @@ def train(args):
|
|||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
unet = accelerator.unwrap_model(unet)
|
unet = unwrap_model(unet)
|
||||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
text_encoder = unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
@ -1118,7 +1120,8 @@ def train(args):
|
|||||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||||
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
|
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path)
|
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path,
|
||||||
|
use_safetensors=args.use_safetensors)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
@ -1147,9 +1150,9 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--output_dir", type=str, default=None,
|
parser.add_argument("--output_dir", type=str, default=None,
|
||||||
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
||||||
parser.add_argument("--use_safetensors", action='store_true',
|
parser.add_argument("--use_safetensors", action='store_true',
|
||||||
help="use safetensors format for StableDiffusion checkpoint / StableDiffusionのcheckpointをsafetensors形式で保存する")
|
help="use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する")
|
||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
@ -1191,6 +1194,7 @@ if __name__ == '__main__':
|
|||||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||||
|
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
||||||
parser.add_argument("--save_precision", type=str, default=None,
|
parser.add_argument("--save_precision", type=str, default=None,
|
||||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
|
||||||
parser.add_argument("--clip_skip", type=int, default=None,
|
parser.add_argument("--clip_skip", type=int, default=None,
|
||||||
|
Loading…
Reference in New Issue
Block a user