Update train_db_fixed to v9
This commit is contained in:
parent
f56340d53e
commit
0e8b993def
16
README.md
16
README.md
@ -113,9 +113,10 @@ accelerate launch --num_cpu_threads_per_process 6 train_db_fixed-ber.py `
|
|||||||
--cache_latents `
|
--cache_latents `
|
||||||
--save_every_n_epochs=1 `
|
--save_every_n_epochs=1 `
|
||||||
--fine_tuning `
|
--fine_tuning `
|
||||||
|
--enable_bucket `
|
||||||
--dataset_repeats=200 `
|
--dataset_repeats=200 `
|
||||||
--seed=23 `
|
--seed=23 `
|
||||||
--save_half
|
---save_precision="fp16"
|
||||||
```
|
```
|
||||||
|
|
||||||
Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e
|
Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e
|
||||||
@ -125,7 +126,12 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
|
|||||||
* 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short).
|
* 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short).
|
||||||
* 11/9 (v8): supports Diffusers 0.7.2. To upgrade diffusers run `pip install --upgrade diffusers[torch]`
|
* 11/9 (v8): supports Diffusers 0.7.2. To upgrade diffusers run `pip install --upgrade diffusers[torch]`
|
||||||
* 11/14 (diffusers_fine_tuning v2):
|
* 11/14 (diffusers_fine_tuning v2):
|
||||||
- script name is now fine_tune.py.
|
- script name is now fine_tune.py.
|
||||||
- Added option to learn Text Encoder --train_text_encoder.
|
- Added option to learn Text Encoder --train_text_encoder.
|
||||||
- The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16.
|
- The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16.
|
||||||
- Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option.
|
- Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option.
|
||||||
|
* 11/18 (v9):
|
||||||
|
- Added support for Aspect Ratio Bucketing (enable_bucket option). (--enable_bucket)
|
||||||
|
- Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision)
|
||||||
|
- Added support for saving learning state (--save_state, --resume)
|
||||||
|
- Added support for logging (--logging_dir)
|
||||||
|
@ -2,9 +2,12 @@
|
|||||||
#
|
#
|
||||||
# Usefull to create base caption that will be augmented on a per image basis
|
# Usefull to create base caption that will be augmented on a per image basis
|
||||||
|
|
||||||
$folder = "D:\dreambooth\train_sylvia_ritter\raw_data\all-images\"
|
$folder = "D:\some\folder\location\"
|
||||||
$file_pattern="*.*"
|
$file_pattern="*.*"
|
||||||
$text_fir_file="a digital painting of xxx, by silvery trait"
|
$caption_text="some caption text"
|
||||||
|
|
||||||
$files = Get-ChildItem $folder$file_pattern
|
$files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File
|
||||||
foreach ($file in $files) {New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file}
|
foreach ($file in $files)
|
||||||
|
{
|
||||||
|
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
|
||||||
|
}
|
20
examples/caption_subfolders.ps1
Normal file
20
examples/caption_subfolders.ps1
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# This powershell script will create a text file for each files in the folder
|
||||||
|
#
|
||||||
|
# Usefull to create base caption that will be augmented on a per image basis
|
||||||
|
|
||||||
|
$folder = "D:\test\t2\"
|
||||||
|
$file_pattern="*.*"
|
||||||
|
$text_fir_file="bigeyes style"
|
||||||
|
|
||||||
|
foreach ($file in Get-ChildItem $folder\$file_pattern -File)
|
||||||
|
{
|
||||||
|
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file
|
||||||
|
}
|
||||||
|
|
||||||
|
foreach($directory in Get-ChildItem -path $folder -Directory)
|
||||||
|
{
|
||||||
|
foreach ($file in Get-ChildItem $folder\$directory\$file_pattern)
|
||||||
|
{
|
||||||
|
New-Item -ItemType file -Path $folder\$directory -Name "$($file.BaseName).txt" -Value $text_fir_file
|
||||||
|
}
|
||||||
|
}
|
87
examples/kohya-1-folders.ps1
Normal file
87
examples/kohya-1-folders.ps1
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape,
|
||||||
|
# portrait and square images.
|
||||||
|
#
|
||||||
|
# Adjust the script to your own needs
|
||||||
|
|
||||||
|
# Sylvia Ritter
|
||||||
|
# variable values
|
||||||
|
$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt"
|
||||||
|
$data_dir = "D:\test\squat"
|
||||||
|
$train_dir = "D:\test\"
|
||||||
|
$resolution = "512,512"
|
||||||
|
|
||||||
|
$image_num = Get-ChildItem $data_dir -Recurse -File -Include *.png | Measure-Object | %{$_.Count}
|
||||||
|
|
||||||
|
Write-Output "image_num: $image_num"
|
||||||
|
|
||||||
|
$learning_rate = 1e-6
|
||||||
|
$dataset_repeats = 40
|
||||||
|
$train_batch_size = 8
|
||||||
|
$epoch = 1
|
||||||
|
$save_every_n_epochs=1
|
||||||
|
$mixed_precision="fp16"
|
||||||
|
$num_cpu_threads_per_process=6
|
||||||
|
|
||||||
|
# You should not have to change values past this point
|
||||||
|
|
||||||
|
$output_dir = $train_dir + "\model"
|
||||||
|
$repeats = $image_num * $dataset_repeats
|
||||||
|
$mts = [Math]::Ceiling($repeats / $train_batch_size * $epoch)
|
||||||
|
|
||||||
|
Write-Output "Repeats: $repeats"
|
||||||
|
|
||||||
|
.\venv\Scripts\activate
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py `
|
||||||
|
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
|
||||||
|
--train_data_dir=$data_dir `
|
||||||
|
--output_dir=$output_dir `
|
||||||
|
--resolution=$resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
# 2nd pass at half the dataset repeat value
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$data_dir `
|
||||||
|
--output_dir=$output_dir"2" `
|
||||||
|
--resolution=$resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py `
|
||||||
|
--pretrained_model_name_or_path=$output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$data_dir `
|
||||||
|
--output_dir=$output_dir"2" `
|
||||||
|
--resolution=$resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
154
examples/kohya-3-folders.ps1
Normal file
154
examples/kohya-3-folders.ps1
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape,
|
||||||
|
# portrait and square images.
|
||||||
|
#
|
||||||
|
# Adjust the script to your own needs
|
||||||
|
|
||||||
|
# Sylvia Ritter
|
||||||
|
# variable values
|
||||||
|
$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt"
|
||||||
|
$train_dir = "D:\dreambooth\train_sylvia_ritter\raw_data"
|
||||||
|
|
||||||
|
$landscape_image_num = 4
|
||||||
|
$portrait_image_num = 25
|
||||||
|
$square_image_num = 2
|
||||||
|
|
||||||
|
$learning_rate = 1e-6
|
||||||
|
$dataset_repeats = 120
|
||||||
|
$train_batch_size = 4
|
||||||
|
$epoch = 1
|
||||||
|
$save_every_n_epochs=1
|
||||||
|
$mixed_precision="fp16"
|
||||||
|
$num_cpu_threads_per_process=6
|
||||||
|
|
||||||
|
$landscape_folder_name = "landscape-pp"
|
||||||
|
$landscape_resolution = "832,512"
|
||||||
|
$portrait_folder_name = "portrait-pp"
|
||||||
|
$portrait_resolution = "448,896"
|
||||||
|
$square_folder_name = "square-pp"
|
||||||
|
$square_resolution = "512,512"
|
||||||
|
|
||||||
|
# You should not have to change values past this point
|
||||||
|
|
||||||
|
$landscape_data_dir = $train_dir + "\" + $landscape_folder_name
|
||||||
|
$portrait_data_dir = $train_dir + "\" + $portrait_folder_name
|
||||||
|
$square_data_dir = $train_dir + "\" + $square_folder_name
|
||||||
|
$landscape_output_dir = $train_dir + "\model-l"
|
||||||
|
$portrait_output_dir = $train_dir + "\model-lp"
|
||||||
|
$square_output_dir = $train_dir + "\model-lps"
|
||||||
|
|
||||||
|
$landscape_repeats = $landscape_image_num * $dataset_repeats
|
||||||
|
$portrait_repeats = $portrait_image_num * $dataset_repeats
|
||||||
|
$square_repeats = $square_image_num * $dataset_repeats
|
||||||
|
|
||||||
|
$landscape_mts = [Math]::Ceiling($landscape_repeats / $train_batch_size * $epoch)
|
||||||
|
$portrait_mts = [Math]::Ceiling($portrait_repeats / $train_batch_size * $epoch)
|
||||||
|
$square_mts = [Math]::Ceiling($square_repeats / $train_batch_size * $epoch)
|
||||||
|
|
||||||
|
# Write-Output $landscape_repeats
|
||||||
|
|
||||||
|
.\venv\Scripts\activate
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
|
||||||
|
--train_data_dir=$landscape_data_dir `
|
||||||
|
--output_dir=$landscape_output_dir `
|
||||||
|
--resolution=$landscape_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$landscape_mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$landscape_output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$portrait_data_dir `
|
||||||
|
--output_dir=$portrait_output_dir `
|
||||||
|
--resolution=$portrait_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$portrait_mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$portrait_output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$square_data_dir `
|
||||||
|
--output_dir=$square_output_dir `
|
||||||
|
--resolution=$square_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$square_mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
# 2nd pass at half the dataset repeat value
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$square_output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$landscape_data_dir `
|
||||||
|
--output_dir=$landscape_output_dir"2" `
|
||||||
|
--resolution=$landscape_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($landscape_mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$landscape_output_dir"2\last.ckpt" `
|
||||||
|
--train_data_dir=$portrait_data_dir `
|
||||||
|
--output_dir=$portrait_output_dir"2" `
|
||||||
|
--resolution=$portrait_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($portrait_mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$portrait_output_dir"2\last.ckpt" `
|
||||||
|
--train_data_dir=$square_data_dir `
|
||||||
|
--output_dir=$square_output_dir"2" `
|
||||||
|
--resolution=$square_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($square_mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
@ -55,7 +55,7 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\
|
|||||||
--use_8bit_adam --xformers `
|
--use_8bit_adam --xformers `
|
||||||
--mixed_precision=$mixed_precision `
|
--mixed_precision=$mixed_precision `
|
||||||
--save_every_n_epochs=$save_every_n_epochs `
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
--save_half
|
--save_precision="fp16"
|
||||||
|
|
||||||
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\kohya_ss\diffusers_fine_tuning\fine_tune.py `
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\kohya_ss\diffusers_fine_tuning\fine_tune.py `
|
||||||
--pretrained_model_name_or_path=$train_dir"\fine_tuned\last.ckpt" `
|
--pretrained_model_name_or_path=$train_dir"\fine_tuned\last.ckpt" `
|
||||||
@ -69,4 +69,4 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\
|
|||||||
--use_8bit_adam --xformers `
|
--use_8bit_adam --xformers `
|
||||||
--mixed_precision=$mixed_precision `
|
--mixed_precision=$mixed_precision `
|
||||||
--save_every_n_epochs=$save_every_n_epochs `
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
--save_half
|
--save_precision="fp16"
|
||||||
|
@ -4,7 +4,9 @@
|
|||||||
# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
|
# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
|
||||||
# enable reg images in fine-tuning, add dataset_repeats option
|
# enable reg images in fine-tuning, add dataset_repeats option
|
||||||
# v8: supports Diffusers 0.7.2
|
# v8: supports Diffusers 0.7.2
|
||||||
|
# v9: add bucketing option
|
||||||
|
|
||||||
|
import time
|
||||||
from torch.autograd.function import Function
|
from torch.autograd.function import Function
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
@ -56,13 +58,40 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
|
|||||||
|
|
||||||
# checkpointファイル名
|
# checkpointファイル名
|
||||||
LAST_CHECKPOINT_NAME = "last.ckpt"
|
LAST_CHECKPOINT_NAME = "last.ckpt"
|
||||||
|
LAST_STATE_NAME = "last-state"
|
||||||
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
|
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
|
||||||
|
EPOCH_STATE_NAME = "epoch-{:06d}-state"
|
||||||
|
|
||||||
|
|
||||||
|
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
||||||
|
max_width, max_height = max_reso
|
||||||
|
max_area = (max_width // divisible) * (max_height // divisible)
|
||||||
|
|
||||||
|
resos = set()
|
||||||
|
|
||||||
|
size = int(math.sqrt(max_area)) * divisible
|
||||||
|
resos.add((size, size))
|
||||||
|
|
||||||
|
size = min_size
|
||||||
|
while size <= max_size:
|
||||||
|
width = size
|
||||||
|
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
||||||
|
resos.add((width, height))
|
||||||
|
resos.add((height, width))
|
||||||
|
size += divisible
|
||||||
|
|
||||||
|
resos = list(resos)
|
||||||
|
resos.sort()
|
||||||
|
|
||||||
|
aspect_ratios = [w / h for w, h in resos]
|
||||||
|
return resos, aspect_ratios
|
||||||
|
|
||||||
|
|
||||||
class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
|
def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
self.fine_tuning = fine_tuning
|
self.fine_tuning = fine_tuning
|
||||||
self.train_img_path_captions = train_img_path_captions
|
self.train_img_path_captions = train_img_path_captions
|
||||||
self.reg_img_path_captions = reg_img_path_captions
|
self.reg_img_path_captions = reg_img_path_captions
|
||||||
@ -76,6 +105,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
self.shuffle_caption = shuffle_caption
|
self.shuffle_caption = shuffle_caption
|
||||||
self.disable_padding = disable_padding
|
self.disable_padding = disable_padding
|
||||||
self.latents_cache = None
|
self.latents_cache = None
|
||||||
|
self.enable_bucket = False
|
||||||
|
|
||||||
# augmentation
|
# augmentation
|
||||||
flip_p = 0.5 if flip_aug else 0.0
|
flip_p = 0.5 if flip_aug else 0.0
|
||||||
@ -102,12 +132,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.enable_reg_images = self.num_reg_images > 0
|
self.enable_reg_images = self.num_reg_images > 0
|
||||||
|
|
||||||
if not self.enable_reg_images:
|
if self.enable_reg_images and self.num_train_images < self.num_reg_images:
|
||||||
self._length = self.num_train_images
|
|
||||||
else:
|
|
||||||
# 学習データの倍として、奇数ならtrain
|
|
||||||
self._length = self.num_train_images * 2
|
|
||||||
if self._length // 2 < self.num_reg_images:
|
|
||||||
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||||
|
|
||||||
self.image_transforms = transforms.Compose(
|
self.image_transforms = transforms.Compose(
|
||||||
@ -117,6 +142,132 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
||||||
|
def make_buckets_with_caching(self, enable_bucket, vae):
|
||||||
|
self.enable_bucket = enable_bucket
|
||||||
|
|
||||||
|
cache_latents = vae is not None
|
||||||
|
if cache_latents:
|
||||||
|
if enable_bucket:
|
||||||
|
print("cache latents with bucketing")
|
||||||
|
else:
|
||||||
|
print("cache latents")
|
||||||
|
else:
|
||||||
|
if enable_bucket:
|
||||||
|
print("make buckets")
|
||||||
|
else:
|
||||||
|
print("prepare dataset")
|
||||||
|
|
||||||
|
# bucketingを用意する
|
||||||
|
if enable_bucket:
|
||||||
|
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height))
|
||||||
|
else:
|
||||||
|
# bucketはひとつだけ、すべての画像は同じ解像度
|
||||||
|
bucket_resos = [(self.width, self.height)]
|
||||||
|
bucket_aspect_ratios = [self.width / self.height]
|
||||||
|
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
|
||||||
|
|
||||||
|
# 画像の解像度、latentをあらかじめ取得する
|
||||||
|
img_ar_errors = []
|
||||||
|
self.size_lat_cache = {}
|
||||||
|
for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions):
|
||||||
|
if image_path in self.size_lat_cache:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image = self.load_image(image_path)[0]
|
||||||
|
image_height, image_width = image.shape[0:2]
|
||||||
|
|
||||||
|
if not enable_bucket:
|
||||||
|
# assert image_width == self.width and image_height == self.height, \
|
||||||
|
# f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}"
|
||||||
|
reso = (self.width, self.height)
|
||||||
|
else:
|
||||||
|
# bucketを決める
|
||||||
|
aspect_ratio = image_width / image_height
|
||||||
|
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||||
|
bucket_id = np.abs(ar_errors).argmin()
|
||||||
|
reso = bucket_resos[bucket_id]
|
||||||
|
ar_error = ar_errors[bucket_id]
|
||||||
|
img_ar_errors.append(ar_error)
|
||||||
|
|
||||||
|
if cache_latents:
|
||||||
|
image = self.resize_and_trim(image, reso)
|
||||||
|
|
||||||
|
# latentを取得する
|
||||||
|
if cache_latents:
|
||||||
|
img_tensor = self.image_transforms(image)
|
||||||
|
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||||
|
latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
||||||
|
else:
|
||||||
|
latents = None
|
||||||
|
|
||||||
|
self.size_lat_cache[image_path] = (reso, latents)
|
||||||
|
|
||||||
|
# 画像をbucketに分割する
|
||||||
|
self.buckets = [[] for _ in range(len(bucket_resos))]
|
||||||
|
reso_to_index = {}
|
||||||
|
for i, reso in enumerate(bucket_resos):
|
||||||
|
reso_to_index[reso] = i
|
||||||
|
|
||||||
|
def split_to_buckets(is_reg, img_path_captions):
|
||||||
|
for image_path, caption in img_path_captions:
|
||||||
|
reso, _ = self.size_lat_cache[image_path]
|
||||||
|
bucket_index = reso_to_index[reso]
|
||||||
|
self.buckets[bucket_index].append((is_reg, image_path, caption))
|
||||||
|
|
||||||
|
split_to_buckets(False, self.train_img_path_captions)
|
||||||
|
|
||||||
|
if self.enable_reg_images:
|
||||||
|
l = []
|
||||||
|
while len(l) < len(self.train_img_path_captions):
|
||||||
|
l += self.reg_img_path_captions
|
||||||
|
l = l[:len(self.train_img_path_captions)]
|
||||||
|
split_to_buckets(True, l)
|
||||||
|
|
||||||
|
if enable_bucket:
|
||||||
|
print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数")
|
||||||
|
for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)):
|
||||||
|
print(f"bucket {i}: resolution {reso}, count: {len(imgs)}")
|
||||||
|
img_ar_errors = np.array(img_ar_errors)
|
||||||
|
print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}")
|
||||||
|
|
||||||
|
# 参照用indexを作る
|
||||||
|
self.buckets_indices = []
|
||||||
|
for bucket_index, bucket in enumerate(self.buckets):
|
||||||
|
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||||
|
for batch_index in range(batch_count):
|
||||||
|
self.buckets_indices.append((bucket_index, batch_index))
|
||||||
|
|
||||||
|
self.shuffle_buckets()
|
||||||
|
self._length = len(self.buckets_indices)
|
||||||
|
|
||||||
|
# どのサイズにリサイズするか→トリミングする方向で
|
||||||
|
def resize_and_trim(self, image, reso):
|
||||||
|
image_height, image_width = image.shape[0:2]
|
||||||
|
ar_img = image_width / image_height
|
||||||
|
ar_reso = reso[0] / reso[1]
|
||||||
|
if ar_img > ar_reso: # 横が長い→縦を合わせる
|
||||||
|
scale = reso[1] / image_height
|
||||||
|
else:
|
||||||
|
scale = reso[0] / image_width
|
||||||
|
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
||||||
|
|
||||||
|
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
|
if resized_size[0] > reso[0]:
|
||||||
|
trim_size = resized_size[0] - reso[0]
|
||||||
|
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
||||||
|
elif resized_size[1] > reso[1]:
|
||||||
|
trim_size = resized_size[1] - reso[1]
|
||||||
|
image = image[trim_size//2:trim_size//2 + reso[1]]
|
||||||
|
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
|
||||||
|
f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||||
|
return image
|
||||||
|
|
||||||
|
def shuffle_buckets(self):
|
||||||
|
random.shuffle(self.buckets_indices)
|
||||||
|
for bucket in self.buckets:
|
||||||
|
random.shuffle(bucket)
|
||||||
|
|
||||||
def load_image(self, image_path):
|
def load_image(self, image_path):
|
||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
if not image.mode == "RGB":
|
if not image.mode == "RGB":
|
||||||
@ -184,41 +335,32 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._length
|
return self._length
|
||||||
|
|
||||||
def set_cached_latents(self, image_path, latents):
|
def __getitem__(self, index):
|
||||||
if self.latents_cache is None:
|
if index == 0:
|
||||||
self.latents_cache = {}
|
self.shuffle_buckets()
|
||||||
self.latents_cache[image_path] = latents
|
|
||||||
|
|
||||||
def __getitem__(self, index_arg):
|
bucket = self.buckets[self.buckets_indices[index][0]]
|
||||||
example = {}
|
image_index = self.buckets_indices[index][1] * self.batch_size
|
||||||
|
|
||||||
if not self.enable_reg_images:
|
latents_list = []
|
||||||
index = index_arg
|
images = []
|
||||||
img_path_captions = self.train_img_path_captions
|
captions = []
|
||||||
reg = False
|
loss_weights = []
|
||||||
else:
|
|
||||||
# 偶数ならtrain、奇数ならregを返す
|
|
||||||
if index_arg % 2 == 0:
|
|
||||||
img_path_captions = self.train_img_path_captions
|
|
||||||
reg = False
|
|
||||||
else:
|
|
||||||
img_path_captions = self.reg_img_path_captions
|
|
||||||
reg = True
|
|
||||||
index = index_arg // 2
|
|
||||||
example['loss_weight'] = 1.0 if (not reg or self.fine_tuning) else self.prior_loss_weight
|
|
||||||
|
|
||||||
index = index % len(img_path_captions)
|
for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]:
|
||||||
image_path, caption = img_path_captions[index]
|
loss_weights.append(1.0 if is_reg else self.prior_loss_weight)
|
||||||
example['image_path'] = image_path
|
|
||||||
|
|
||||||
# image/latentsを処理する
|
# image/latentsを処理する
|
||||||
if self.latents_cache is not None and image_path in self.latents_cache:
|
reso, latents = self.size_lat_cache[image_path]
|
||||||
# latentsはキャッシュ済み
|
|
||||||
example['latents'] = self.latents_cache[image_path]
|
if latents is None:
|
||||||
else:
|
|
||||||
# 画像を読み込み必要ならcropする
|
# 画像を読み込み必要ならcropする
|
||||||
img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
|
img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
|
||||||
im_h, im_w = img.shape[0:2]
|
im_h, im_w = img.shape[0:2]
|
||||||
|
|
||||||
|
if self.enable_bucket:
|
||||||
|
img = self.resize_and_trim(img, reso)
|
||||||
|
else:
|
||||||
if face_cx > 0: # 顔位置情報あり
|
if face_cx > 0: # 顔位置情報あり
|
||||||
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
||||||
elif im_h > self.height or im_w > self.width:
|
elif im_h > self.height or im_w > self.width:
|
||||||
@ -231,36 +373,47 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
img = img[:, p:p + self.width]
|
img = img[:, p:p + self.width]
|
||||||
|
|
||||||
im_h, im_w = img.shape[0:2]
|
im_h, im_w = img.shape[0:2]
|
||||||
assert im_h == self.height and im_w == self.width, f"image too small / 画像サイズが小さいようです: {image_path}"
|
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}"
|
||||||
|
|
||||||
# augmentation
|
# augmentation
|
||||||
if self.aug is not None:
|
if self.aug is not None:
|
||||||
img = self.aug(image=img)['image']
|
img = self.aug(image=img)['image']
|
||||||
|
|
||||||
example['image'] = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||||
|
else:
|
||||||
|
image = None
|
||||||
|
|
||||||
|
images.append(image)
|
||||||
|
latents_list.append(latents)
|
||||||
|
|
||||||
# captionを処理する
|
# captionを処理する
|
||||||
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする
|
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする
|
||||||
tokens = caption.strip().split(",")
|
tokens = caption.strip().split(",")
|
||||||
random.shuffle(tokens)
|
random.shuffle(tokens)
|
||||||
caption = ",".join(tokens).strip()
|
caption = ",".join(tokens).strip()
|
||||||
|
captions.append(caption)
|
||||||
|
|
||||||
input_ids = self.tokenizer(caption, padding="do_not_pad", truncation=True,
|
# input_idsをpadしてTensor変換
|
||||||
max_length=self.tokenizer.model_max_length).input_ids
|
|
||||||
|
|
||||||
# padしてTensor変換
|
|
||||||
if self.disable_padding:
|
if self.disable_padding:
|
||||||
# paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?)
|
# paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?)
|
||||||
input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||||
else:
|
else:
|
||||||
# paddingする
|
# paddingする
|
||||||
input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding='max_length', max_length=self.tokenizer.model_max_length,
|
input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
|
||||||
return_tensors='pt').input_ids
|
|
||||||
|
|
||||||
|
example = {}
|
||||||
|
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
||||||
example['input_ids'] = input_ids
|
example['input_ids'] = input_ids
|
||||||
|
if images[0] is not None:
|
||||||
|
images = torch.stack(images)
|
||||||
|
images = images.to(memory_format=torch.contiguous_format).float()
|
||||||
|
else:
|
||||||
|
images = None
|
||||||
|
example['images'] = images
|
||||||
|
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||||
if self.debug_dataset:
|
if self.debug_dataset:
|
||||||
example['caption'] = caption
|
example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]]
|
||||||
|
example['captions'] = captions
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
|
||||||
@ -916,7 +1069,7 @@ def load_models_from_stable_diffusion_checkpoint(ckpt_path):
|
|||||||
return text_model, vae, unet
|
return text_model, vae, unet
|
||||||
|
|
||||||
|
|
||||||
def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps):
|
def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
|
||||||
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
|
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
|
||||||
checkpoint = load_checkpoint_with_conversion(ckpt_path)
|
checkpoint = load_checkpoint_with_conversion(ckpt_path)
|
||||||
state_dict = checkpoint["state_dict"]
|
state_dict = checkpoint["state_dict"]
|
||||||
@ -926,6 +1079,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
|
|||||||
for k, v in unet_state_dict.items():
|
for k, v in unet_state_dict.items():
|
||||||
key = "model.diffusion_model." + k
|
key = "model.diffusion_model." + k
|
||||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||||
|
if save_dtype is not None:
|
||||||
|
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
state_dict[key] = v
|
state_dict[key] = v
|
||||||
|
|
||||||
# Convert the text encoder model
|
# Convert the text encoder model
|
||||||
@ -933,6 +1088,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
|
|||||||
for k, v in text_enc_dict.items():
|
for k, v in text_enc_dict.items():
|
||||||
key = "cond_stage_model.transformer." + k
|
key = "cond_stage_model.transformer." + k
|
||||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||||
|
if save_dtype is not None:
|
||||||
|
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
state_dict[key] = v
|
state_dict[key] = v
|
||||||
|
|
||||||
# Put together new checkpoint
|
# Put together new checkpoint
|
||||||
@ -951,24 +1108,7 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
|
|||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
input_ids = [e['input_ids'] for e in examples]
|
return examples[0]
|
||||||
input_ids = torch.stack(input_ids)
|
|
||||||
|
|
||||||
if 'latents' in examples[0]:
|
|
||||||
pixel_values = None
|
|
||||||
latents = [e['latents'] for e in examples]
|
|
||||||
latents = torch.stack(latents)
|
|
||||||
else:
|
|
||||||
pixel_values = [e['image'] for e in examples]
|
|
||||||
pixel_values = torch.stack(pixel_values)
|
|
||||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
|
||||||
latents = None
|
|
||||||
|
|
||||||
loss_weights = [e['loss_weight'] for e in examples]
|
|
||||||
loss_weights = torch.FloatTensor(loss_weights)
|
|
||||||
|
|
||||||
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "latents": latents, "loss_weights": loss_weights}
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@ -998,19 +1138,22 @@ def train(args):
|
|||||||
try:
|
try:
|
||||||
n_repeats = int(tokens[0])
|
n_repeats = int(tokens[0])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
|
# print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
|
||||||
raise e
|
# raise e
|
||||||
|
return 0, []
|
||||||
|
|
||||||
caption = '_'.join(tokens[1:])
|
caption = '_'.join(tokens[1:])
|
||||||
|
|
||||||
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg"))
|
print(f"found directory {n_repeats}_{caption}")
|
||||||
|
|
||||||
|
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp"))
|
||||||
return n_repeats, [(ip, caption) for ip in img_paths]
|
return n_repeats, [(ip, caption) for ip in img_paths]
|
||||||
|
|
||||||
print("prepare train images.")
|
print("prepare train images.")
|
||||||
train_img_path_captions = []
|
train_img_path_captions = []
|
||||||
|
|
||||||
if fine_tuning:
|
if fine_tuning:
|
||||||
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg"))
|
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
for img_path in tqdm(img_paths):
|
for img_path in tqdm(img_paths):
|
||||||
# captionの候補ファイル名を作る
|
# captionの候補ファイル名を作る
|
||||||
base_name = os.path.splitext(img_path)[0]
|
base_name = os.path.splitext(img_path)[0]
|
||||||
@ -1042,7 +1185,7 @@ def train(args):
|
|||||||
n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
|
n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
|
||||||
for _ in range(n_repeats):
|
for _ in range(n_repeats):
|
||||||
train_img_path_captions.extend(img_caps)
|
train_img_path_captions.extend(img_caps)
|
||||||
print(f"{len(train_img_path_captions)} train images.")
|
print(f"{len(train_img_path_captions)} train images with repeating.")
|
||||||
|
|
||||||
reg_img_path_captions = []
|
reg_img_path_captions = []
|
||||||
if args.reg_data_dir:
|
if args.reg_data_dir:
|
||||||
@ -1054,11 +1197,6 @@ def train(args):
|
|||||||
reg_img_path_captions.extend(img_caps)
|
reg_img_path_captions.extend(img_caps)
|
||||||
print(f"{len(reg_img_path_captions)} reg images.")
|
print(f"{len(reg_img_path_captions)} reg images.")
|
||||||
|
|
||||||
if args.debug_dataset:
|
|
||||||
# デバッグ時はshuffleして実際のデータセット使用時に近づける(学習時はdata loaderでshuffleする)
|
|
||||||
random.shuffle(train_img_path_captions)
|
|
||||||
random.shuffle(reg_img_path_captions)
|
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
resolution = tuple([int(r) for r in args.resolution.split(',')])
|
resolution = tuple([int(r) for r in args.resolution.split(',')])
|
||||||
if len(resolution) == 1:
|
if len(resolution) == 1:
|
||||||
@ -1078,29 +1216,40 @@ def train(args):
|
|||||||
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
||||||
|
|
||||||
print("prepare dataset")
|
print("prepare dataset")
|
||||||
train_dataset = DreamBoothOrFineTuningDataset(fine_tuning, train_img_path_captions,
|
train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution,
|
||||||
reg_img_path_captions, tokenizer, resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.shuffle_caption, args.no_token_padding, args.debug_dataset)
|
args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop,
|
||||||
|
args.shuffle_caption, args.no_token_padding, args.debug_dataset)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
|
train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る
|
||||||
|
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||||
print("Escape for exit. / Escキーで中断、終了します")
|
print("Escape for exit. / Escキーで中断、終了します")
|
||||||
for example in train_dataset:
|
for example in train_dataset:
|
||||||
im = example['image']
|
for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']):
|
||||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||||
print(f'caption: "{example["caption"]}", loss weight: {example["loss_weight"]}')
|
print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}')
|
||||||
cv2.imshow("img", im)
|
cv2.imshow("img", im)
|
||||||
k = cv2.waitKey()
|
k = cv2.waitKey()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
if k == 27:
|
if k == 27:
|
||||||
break
|
break
|
||||||
|
if k == 27:
|
||||||
|
break
|
||||||
return
|
return
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
# gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする
|
# gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする
|
||||||
print("prepare accelerator")
|
print("prepare accelerator")
|
||||||
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision)
|
if args.logging_dir is None:
|
||||||
|
log_with = None
|
||||||
|
logging_dir = None
|
||||||
|
else:
|
||||||
|
log_with = "tensorboard"
|
||||||
|
logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
|
||||||
|
log_with=log_with, logging_dir=logging_dir)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
@ -1122,28 +1271,24 @@ def train(args):
|
|||||||
elif args.mixed_precision == "bf16":
|
elif args.mixed_precision == "bf16":
|
||||||
weight_dtype = torch.bfloat16
|
weight_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
save_dtype = None
|
||||||
|
if args.save_precision == "fp16":
|
||||||
|
save_dtype = torch.float16
|
||||||
|
elif args.save_precision == "bf16":
|
||||||
|
save_dtype = torch.bfloat16
|
||||||
|
elif args.save_precision == "float":
|
||||||
|
save_dtype = torch.float32
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
# latentをcacheする→新しいDatasetを作るとcaptionのshuffleが効かないので元のDatasetにcacheを持つ(cascadeする手もあるが)
|
|
||||||
print("caching latents.")
|
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
for i in tqdm(range(len(train_dataset))):
|
|
||||||
example = train_dataset[i]
|
|
||||||
if 'latents' not in example:
|
|
||||||
image_path = example['image_path']
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pixel_values = example["image"].unsqueeze(0).to(device=accelerator.device, dtype=weight_dtype)
|
train_dataset.make_buckets_with_caching(args.enable_bucket, vae)
|
||||||
latents = vae.encode(pixel_values).latent_dist.sample().squeeze(0).to("cpu")
|
|
||||||
train_dataset.set_cached_latents(image_path, latents)
|
|
||||||
# assertion
|
|
||||||
for i in range(len(train_dataset)):
|
|
||||||
assert 'latents' in train_dataset[i], "internal error: latents not cached"
|
|
||||||
|
|
||||||
del vae
|
del vae
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
|
train_dataset.make_buckets_with_caching(args.enable_bucket, None)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
@ -1173,7 +1318,7 @@ def train(args):
|
|||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
|
lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
|
||||||
@ -1185,6 +1330,11 @@ def train(args):
|
|||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# resumeする
|
||||||
|
if args.resume is not None:
|
||||||
|
print(f"resume training from state: {args.resume}")
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
|
num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
|
||||||
|
|
||||||
@ -1193,7 +1343,7 @@ def train(args):
|
|||||||
print("running training / 学習開始")
|
print("running training / 学習開始")
|
||||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
||||||
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
||||||
print(f" num examples / サンプル数: {len(train_dataset)}")
|
print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}")
|
||||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
@ -1222,7 +1372,7 @@ def train(args):
|
|||||||
if cache_latents:
|
if cache_latents:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
else:
|
else:
|
||||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
@ -1271,15 +1421,22 @@ def train(args):
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
loss_total += current_loss
|
loss_total += current_loss
|
||||||
avr_loss = loss_total / (step+1)
|
avr_loss = loss_total / (step+1)
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
# accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
||||||
|
accelerator.log(logs, step=epoch+1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if use_stable_diffusion_format and args.save_every_n_epochs is not None:
|
if use_stable_diffusion_format and args.save_every_n_epochs is not None:
|
||||||
@ -1288,7 +1445,11 @@ def train(args):
|
|||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
|
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
|
||||||
save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
||||||
args.pretrained_model_name_or_path, epoch + 1, global_step)
|
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
|
||||||
|
|
||||||
|
if args.save_state:
|
||||||
|
print("saving state.")
|
||||||
|
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
@ -1296,6 +1457,11 @@ def train(args):
|
|||||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
if args.save_state:
|
||||||
|
print("saving last state.")
|
||||||
|
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
|
||||||
|
|
||||||
del accelerator # この後メモリを使うのでこれは消す
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
@ -1303,7 +1469,8 @@ def train(args):
|
|||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
||||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||||
save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step)
|
save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet,
|
||||||
|
args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
|
||||||
else:
|
else:
|
||||||
# Create the pipeline using using the trained modules and save it.
|
# Create the pipeline using using the trained modules and save it.
|
||||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||||
@ -1589,6 +1756,10 @@ if __name__ == '__main__':
|
|||||||
help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
|
help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
|
||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
||||||
|
parser.add_argument("--save_state", action="store_true",
|
||||||
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
|
parser.add_argument("--resume", type=str, default=None,
|
||||||
|
help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
|
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
|
||||||
parser.add_argument("--no_token_padding", action="store_true",
|
parser.add_argument("--no_token_padding", action="store_true",
|
||||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
||||||
@ -1612,6 +1783,8 @@ if __name__ == '__main__':
|
|||||||
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||||
parser.add_argument("--cache_latents", action="store_true",
|
parser.add_argument("--cache_latents", action="store_true",
|
||||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
||||||
|
parser.add_argument("--enable_bucket", action="store_true",
|
||||||
|
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
||||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||||
@ -1619,8 +1792,12 @@ if __name__ == '__main__':
|
|||||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||||
|
parser.add_argument("--save_precision", type=str, default=None,
|
||||||
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
||||||
parser.add_argument("--clip_skip", type=int, default=None,
|
parser.add_argument("--clip_skip", type=int, default=None,
|
||||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||||
|
parser.add_argument("--logging_dir", type=str, default=None,
|
||||||
|
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
Loading…
Reference in New Issue
Block a user