Update train_db_fixed to v9
This commit is contained in:
parent
f56340d53e
commit
0e8b993def
@ -113,9 +113,10 @@ accelerate launch --num_cpu_threads_per_process 6 train_db_fixed-ber.py `
|
|||||||
--cache_latents `
|
--cache_latents `
|
||||||
--save_every_n_epochs=1 `
|
--save_every_n_epochs=1 `
|
||||||
--fine_tuning `
|
--fine_tuning `
|
||||||
|
--enable_bucket `
|
||||||
--dataset_repeats=200 `
|
--dataset_repeats=200 `
|
||||||
--seed=23 `
|
--seed=23 `
|
||||||
--save_half
|
---save_precision="fp16"
|
||||||
```
|
```
|
||||||
|
|
||||||
Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e
|
Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e
|
||||||
@ -129,3 +130,8 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
|
|||||||
- Added option to learn Text Encoder --train_text_encoder.
|
- Added option to learn Text Encoder --train_text_encoder.
|
||||||
- The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16.
|
- The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16.
|
||||||
- Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option.
|
- Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option.
|
||||||
|
* 11/18 (v9):
|
||||||
|
- Added support for Aspect Ratio Bucketing (enable_bucket option). (--enable_bucket)
|
||||||
|
- Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision)
|
||||||
|
- Added support for saving learning state (--save_state, --resume)
|
||||||
|
- Added support for logging (--logging_dir)
|
||||||
|
@ -2,9 +2,12 @@
|
|||||||
#
|
#
|
||||||
# Usefull to create base caption that will be augmented on a per image basis
|
# Usefull to create base caption that will be augmented on a per image basis
|
||||||
|
|
||||||
$folder = "D:\dreambooth\train_sylvia_ritter\raw_data\all-images\"
|
$folder = "D:\some\folder\location\"
|
||||||
$file_pattern="*.*"
|
$file_pattern="*.*"
|
||||||
$text_fir_file="a digital painting of xxx, by silvery trait"
|
$caption_text="some caption text"
|
||||||
|
|
||||||
$files = Get-ChildItem $folder$file_pattern
|
$files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File
|
||||||
foreach ($file in $files) {New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file}
|
foreach ($file in $files)
|
||||||
|
{
|
||||||
|
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
|
||||||
|
}
|
20
examples/caption_subfolders.ps1
Normal file
20
examples/caption_subfolders.ps1
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# This powershell script will create a text file for each files in the folder
|
||||||
|
#
|
||||||
|
# Usefull to create base caption that will be augmented on a per image basis
|
||||||
|
|
||||||
|
$folder = "D:\test\t2\"
|
||||||
|
$file_pattern="*.*"
|
||||||
|
$text_fir_file="bigeyes style"
|
||||||
|
|
||||||
|
foreach ($file in Get-ChildItem $folder\$file_pattern -File)
|
||||||
|
{
|
||||||
|
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file
|
||||||
|
}
|
||||||
|
|
||||||
|
foreach($directory in Get-ChildItem -path $folder -Directory)
|
||||||
|
{
|
||||||
|
foreach ($file in Get-ChildItem $folder\$directory\$file_pattern)
|
||||||
|
{
|
||||||
|
New-Item -ItemType file -Path $folder\$directory -Name "$($file.BaseName).txt" -Value $text_fir_file
|
||||||
|
}
|
||||||
|
}
|
87
examples/kohya-1-folders.ps1
Normal file
87
examples/kohya-1-folders.ps1
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape,
|
||||||
|
# portrait and square images.
|
||||||
|
#
|
||||||
|
# Adjust the script to your own needs
|
||||||
|
|
||||||
|
# Sylvia Ritter
|
||||||
|
# variable values
|
||||||
|
$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt"
|
||||||
|
$data_dir = "D:\test\squat"
|
||||||
|
$train_dir = "D:\test\"
|
||||||
|
$resolution = "512,512"
|
||||||
|
|
||||||
|
$image_num = Get-ChildItem $data_dir -Recurse -File -Include *.png | Measure-Object | %{$_.Count}
|
||||||
|
|
||||||
|
Write-Output "image_num: $image_num"
|
||||||
|
|
||||||
|
$learning_rate = 1e-6
|
||||||
|
$dataset_repeats = 40
|
||||||
|
$train_batch_size = 8
|
||||||
|
$epoch = 1
|
||||||
|
$save_every_n_epochs=1
|
||||||
|
$mixed_precision="fp16"
|
||||||
|
$num_cpu_threads_per_process=6
|
||||||
|
|
||||||
|
# You should not have to change values past this point
|
||||||
|
|
||||||
|
$output_dir = $train_dir + "\model"
|
||||||
|
$repeats = $image_num * $dataset_repeats
|
||||||
|
$mts = [Math]::Ceiling($repeats / $train_batch_size * $epoch)
|
||||||
|
|
||||||
|
Write-Output "Repeats: $repeats"
|
||||||
|
|
||||||
|
.\venv\Scripts\activate
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py `
|
||||||
|
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
|
||||||
|
--train_data_dir=$data_dir `
|
||||||
|
--output_dir=$output_dir `
|
||||||
|
--resolution=$resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
# 2nd pass at half the dataset repeat value
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$data_dir `
|
||||||
|
--output_dir=$output_dir"2" `
|
||||||
|
--resolution=$resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py `
|
||||||
|
--pretrained_model_name_or_path=$output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$data_dir `
|
||||||
|
--output_dir=$output_dir"2" `
|
||||||
|
--resolution=$resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
154
examples/kohya-3-folders.ps1
Normal file
154
examples/kohya-3-folders.ps1
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape,
|
||||||
|
# portrait and square images.
|
||||||
|
#
|
||||||
|
# Adjust the script to your own needs
|
||||||
|
|
||||||
|
# Sylvia Ritter
|
||||||
|
# variable values
|
||||||
|
$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt"
|
||||||
|
$train_dir = "D:\dreambooth\train_sylvia_ritter\raw_data"
|
||||||
|
|
||||||
|
$landscape_image_num = 4
|
||||||
|
$portrait_image_num = 25
|
||||||
|
$square_image_num = 2
|
||||||
|
|
||||||
|
$learning_rate = 1e-6
|
||||||
|
$dataset_repeats = 120
|
||||||
|
$train_batch_size = 4
|
||||||
|
$epoch = 1
|
||||||
|
$save_every_n_epochs=1
|
||||||
|
$mixed_precision="fp16"
|
||||||
|
$num_cpu_threads_per_process=6
|
||||||
|
|
||||||
|
$landscape_folder_name = "landscape-pp"
|
||||||
|
$landscape_resolution = "832,512"
|
||||||
|
$portrait_folder_name = "portrait-pp"
|
||||||
|
$portrait_resolution = "448,896"
|
||||||
|
$square_folder_name = "square-pp"
|
||||||
|
$square_resolution = "512,512"
|
||||||
|
|
||||||
|
# You should not have to change values past this point
|
||||||
|
|
||||||
|
$landscape_data_dir = $train_dir + "\" + $landscape_folder_name
|
||||||
|
$portrait_data_dir = $train_dir + "\" + $portrait_folder_name
|
||||||
|
$square_data_dir = $train_dir + "\" + $square_folder_name
|
||||||
|
$landscape_output_dir = $train_dir + "\model-l"
|
||||||
|
$portrait_output_dir = $train_dir + "\model-lp"
|
||||||
|
$square_output_dir = $train_dir + "\model-lps"
|
||||||
|
|
||||||
|
$landscape_repeats = $landscape_image_num * $dataset_repeats
|
||||||
|
$portrait_repeats = $portrait_image_num * $dataset_repeats
|
||||||
|
$square_repeats = $square_image_num * $dataset_repeats
|
||||||
|
|
||||||
|
$landscape_mts = [Math]::Ceiling($landscape_repeats / $train_batch_size * $epoch)
|
||||||
|
$portrait_mts = [Math]::Ceiling($portrait_repeats / $train_batch_size * $epoch)
|
||||||
|
$square_mts = [Math]::Ceiling($square_repeats / $train_batch_size * $epoch)
|
||||||
|
|
||||||
|
# Write-Output $landscape_repeats
|
||||||
|
|
||||||
|
.\venv\Scripts\activate
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
|
||||||
|
--train_data_dir=$landscape_data_dir `
|
||||||
|
--output_dir=$landscape_output_dir `
|
||||||
|
--resolution=$landscape_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$landscape_mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$landscape_output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$portrait_data_dir `
|
||||||
|
--output_dir=$portrait_output_dir `
|
||||||
|
--resolution=$portrait_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$portrait_mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$portrait_output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$square_data_dir `
|
||||||
|
--output_dir=$square_output_dir `
|
||||||
|
--resolution=$square_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$square_mts `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$dataset_repeats `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
# 2nd pass at half the dataset repeat value
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$square_output_dir"\last.ckpt" `
|
||||||
|
--train_data_dir=$landscape_data_dir `
|
||||||
|
--output_dir=$landscape_output_dir"2" `
|
||||||
|
--resolution=$landscape_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($landscape_mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$landscape_output_dir"2\last.ckpt" `
|
||||||
|
--train_data_dir=$portrait_data_dir `
|
||||||
|
--output_dir=$portrait_output_dir"2" `
|
||||||
|
--resolution=$portrait_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($portrait_mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
||||||
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
|
||||||
|
--pretrained_model_name_or_path=$portrait_output_dir"2\last.ckpt" `
|
||||||
|
--train_data_dir=$square_data_dir `
|
||||||
|
--output_dir=$square_output_dir"2" `
|
||||||
|
--resolution=$square_resolution `
|
||||||
|
--train_batch_size=$train_batch_size `
|
||||||
|
--learning_rate=$learning_rate `
|
||||||
|
--max_train_steps=$([Math]::Ceiling($square_mts/2)) `
|
||||||
|
--use_8bit_adam `
|
||||||
|
--xformers `
|
||||||
|
--mixed_precision=$mixed_precision `
|
||||||
|
--cache_latents `
|
||||||
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
|
--fine_tuning `
|
||||||
|
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
|
||||||
|
--save_precision="fp16"
|
||||||
|
|
@ -55,7 +55,7 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\
|
|||||||
--use_8bit_adam --xformers `
|
--use_8bit_adam --xformers `
|
||||||
--mixed_precision=$mixed_precision `
|
--mixed_precision=$mixed_precision `
|
||||||
--save_every_n_epochs=$save_every_n_epochs `
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
--save_half
|
--save_precision="fp16"
|
||||||
|
|
||||||
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\kohya_ss\diffusers_fine_tuning\fine_tune.py `
|
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\kohya_ss\diffusers_fine_tuning\fine_tune.py `
|
||||||
--pretrained_model_name_or_path=$train_dir"\fine_tuned\last.ckpt" `
|
--pretrained_model_name_or_path=$train_dir"\fine_tuned\last.ckpt" `
|
||||||
@ -69,4 +69,4 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\
|
|||||||
--use_8bit_adam --xformers `
|
--use_8bit_adam --xformers `
|
||||||
--mixed_precision=$mixed_precision `
|
--mixed_precision=$mixed_precision `
|
||||||
--save_every_n_epochs=$save_every_n_epochs `
|
--save_every_n_epochs=$save_every_n_epochs `
|
||||||
--save_half
|
--save_precision="fp16"
|
||||||
|
@ -4,7 +4,9 @@
|
|||||||
# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
|
# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
|
||||||
# enable reg images in fine-tuning, add dataset_repeats option
|
# enable reg images in fine-tuning, add dataset_repeats option
|
||||||
# v8: supports Diffusers 0.7.2
|
# v8: supports Diffusers 0.7.2
|
||||||
|
# v9: add bucketing option
|
||||||
|
|
||||||
|
import time
|
||||||
from torch.autograd.function import Function
|
from torch.autograd.function import Function
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
@ -56,13 +58,40 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
|
|||||||
|
|
||||||
# checkpointファイル名
|
# checkpointファイル名
|
||||||
LAST_CHECKPOINT_NAME = "last.ckpt"
|
LAST_CHECKPOINT_NAME = "last.ckpt"
|
||||||
|
LAST_STATE_NAME = "last-state"
|
||||||
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
|
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
|
||||||
|
EPOCH_STATE_NAME = "epoch-{:06d}-state"
|
||||||
|
|
||||||
|
|
||||||
|
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
||||||
|
max_width, max_height = max_reso
|
||||||
|
max_area = (max_width // divisible) * (max_height // divisible)
|
||||||
|
|
||||||
|
resos = set()
|
||||||
|
|
||||||
|
size = int(math.sqrt(max_area)) * divisible
|
||||||
|
resos.add((size, size))
|
||||||
|
|
||||||
|
size = min_size
|
||||||
|
while size <= max_size:
|
||||||
|
width = size
|
||||||
|
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
||||||
|
resos.add((width, height))
|
||||||
|
resos.add((height, width))
|
||||||
|
size += divisible
|
||||||
|
|
||||||
|
resos = list(resos)
|
||||||
|
resos.sort()
|
||||||
|
|
||||||
|
aspect_ratios = [w / h for w, h in resos]
|
||||||
|
return resos, aspect_ratios
|
||||||
|
|
||||||
|
|
||||||
class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
|
def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
self.fine_tuning = fine_tuning
|
self.fine_tuning = fine_tuning
|
||||||
self.train_img_path_captions = train_img_path_captions
|
self.train_img_path_captions = train_img_path_captions
|
||||||
self.reg_img_path_captions = reg_img_path_captions
|
self.reg_img_path_captions = reg_img_path_captions
|
||||||
@ -76,6 +105,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
self.shuffle_caption = shuffle_caption
|
self.shuffle_caption = shuffle_caption
|
||||||
self.disable_padding = disable_padding
|
self.disable_padding = disable_padding
|
||||||
self.latents_cache = None
|
self.latents_cache = None
|
||||||
|
self.enable_bucket = False
|
||||||
|
|
||||||
# augmentation
|
# augmentation
|
||||||
flip_p = 0.5 if flip_aug else 0.0
|
flip_p = 0.5 if flip_aug else 0.0
|
||||||
@ -102,12 +132,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.enable_reg_images = self.num_reg_images > 0
|
self.enable_reg_images = self.num_reg_images > 0
|
||||||
|
|
||||||
if not self.enable_reg_images:
|
if self.enable_reg_images and self.num_train_images < self.num_reg_images:
|
||||||
self._length = self.num_train_images
|
|
||||||
else:
|
|
||||||
# 学習データの倍として、奇数ならtrain
|
|
||||||
self._length = self.num_train_images * 2
|
|
||||||
if self._length // 2 < self.num_reg_images:
|
|
||||||
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||||
|
|
||||||
self.image_transforms = transforms.Compose(
|
self.image_transforms = transforms.Compose(
|
||||||
@ -117,6 +142,132 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
||||||
|
def make_buckets_with_caching(self, enable_bucket, vae):
|
||||||
|
self.enable_bucket = enable_bucket
|
||||||
|
|
||||||
|
cache_latents = vae is not None
|
||||||
|
if cache_latents:
|
||||||
|
if enable_bucket:
|
||||||
|
print("cache latents with bucketing")
|
||||||
|
else:
|
||||||
|
print("cache latents")
|
||||||
|
else:
|
||||||
|
if enable_bucket:
|
||||||
|
print("make buckets")
|
||||||
|
else:
|
||||||
|
print("prepare dataset")
|
||||||
|
|
||||||
|
# bucketingを用意する
|
||||||
|
if enable_bucket:
|
||||||
|
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height))
|
||||||
|
else:
|
||||||
|
# bucketはひとつだけ、すべての画像は同じ解像度
|
||||||
|
bucket_resos = [(self.width, self.height)]
|
||||||
|
bucket_aspect_ratios = [self.width / self.height]
|
||||||
|
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
|
||||||
|
|
||||||
|
# 画像の解像度、latentをあらかじめ取得する
|
||||||
|
img_ar_errors = []
|
||||||
|
self.size_lat_cache = {}
|
||||||
|
for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions):
|
||||||
|
if image_path in self.size_lat_cache:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image = self.load_image(image_path)[0]
|
||||||
|
image_height, image_width = image.shape[0:2]
|
||||||
|
|
||||||
|
if not enable_bucket:
|
||||||
|
# assert image_width == self.width and image_height == self.height, \
|
||||||
|
# f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}"
|
||||||
|
reso = (self.width, self.height)
|
||||||
|
else:
|
||||||
|
# bucketを決める
|
||||||
|
aspect_ratio = image_width / image_height
|
||||||
|
ar_errors = bucket_aspect_ratios - aspect_ratio
|
||||||
|
bucket_id = np.abs(ar_errors).argmin()
|
||||||
|
reso = bucket_resos[bucket_id]
|
||||||
|
ar_error = ar_errors[bucket_id]
|
||||||
|
img_ar_errors.append(ar_error)
|
||||||
|
|
||||||
|
if cache_latents:
|
||||||
|
image = self.resize_and_trim(image, reso)
|
||||||
|
|
||||||
|
# latentを取得する
|
||||||
|
if cache_latents:
|
||||||
|
img_tensor = self.image_transforms(image)
|
||||||
|
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||||
|
latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
||||||
|
else:
|
||||||
|
latents = None
|
||||||
|
|
||||||
|
self.size_lat_cache[image_path] = (reso, latents)
|
||||||
|
|
||||||
|
# 画像をbucketに分割する
|
||||||
|
self.buckets = [[] for _ in range(len(bucket_resos))]
|
||||||
|
reso_to_index = {}
|
||||||
|
for i, reso in enumerate(bucket_resos):
|
||||||
|
reso_to_index[reso] = i
|
||||||
|
|
||||||
|
def split_to_buckets(is_reg, img_path_captions):
|
||||||
|
for image_path, caption in img_path_captions:
|
||||||
|
reso, _ = self.size_lat_cache[image_path]
|
||||||
|
bucket_index = reso_to_index[reso]
|
||||||
|
self.buckets[bucket_index].append((is_reg, image_path, caption))
|
||||||
|
|
||||||
|
split_to_buckets(False, self.train_img_path_captions)
|
||||||
|
|
||||||
|
if self.enable_reg_images:
|
||||||
|
l = []
|
||||||
|
while len(l) < len(self.train_img_path_captions):
|
||||||
|
l += self.reg_img_path_captions
|
||||||
|
l = l[:len(self.train_img_path_captions)]
|
||||||
|
split_to_buckets(True, l)
|
||||||
|
|
||||||
|
if enable_bucket:
|
||||||
|
print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数")
|
||||||
|
for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)):
|
||||||
|
print(f"bucket {i}: resolution {reso}, count: {len(imgs)}")
|
||||||
|
img_ar_errors = np.array(img_ar_errors)
|
||||||
|
print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}")
|
||||||
|
|
||||||
|
# 参照用indexを作る
|
||||||
|
self.buckets_indices = []
|
||||||
|
for bucket_index, bucket in enumerate(self.buckets):
|
||||||
|
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||||
|
for batch_index in range(batch_count):
|
||||||
|
self.buckets_indices.append((bucket_index, batch_index))
|
||||||
|
|
||||||
|
self.shuffle_buckets()
|
||||||
|
self._length = len(self.buckets_indices)
|
||||||
|
|
||||||
|
# どのサイズにリサイズするか→トリミングする方向で
|
||||||
|
def resize_and_trim(self, image, reso):
|
||||||
|
image_height, image_width = image.shape[0:2]
|
||||||
|
ar_img = image_width / image_height
|
||||||
|
ar_reso = reso[0] / reso[1]
|
||||||
|
if ar_img > ar_reso: # 横が長い→縦を合わせる
|
||||||
|
scale = reso[1] / image_height
|
||||||
|
else:
|
||||||
|
scale = reso[0] / image_width
|
||||||
|
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
||||||
|
|
||||||
|
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||||
|
if resized_size[0] > reso[0]:
|
||||||
|
trim_size = resized_size[0] - reso[0]
|
||||||
|
image = image[:, trim_size//2:trim_size//2 + reso[0]]
|
||||||
|
elif resized_size[1] > reso[1]:
|
||||||
|
trim_size = resized_size[1] - reso[1]
|
||||||
|
image = image[trim_size//2:trim_size//2 + reso[1]]
|
||||||
|
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
|
||||||
|
f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||||
|
return image
|
||||||
|
|
||||||
|
def shuffle_buckets(self):
|
||||||
|
random.shuffle(self.buckets_indices)
|
||||||
|
for bucket in self.buckets:
|
||||||
|
random.shuffle(bucket)
|
||||||
|
|
||||||
def load_image(self, image_path):
|
def load_image(self, image_path):
|
||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
if not image.mode == "RGB":
|
if not image.mode == "RGB":
|
||||||
@ -184,41 +335,32 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._length
|
return self._length
|
||||||
|
|
||||||
def set_cached_latents(self, image_path, latents):
|
def __getitem__(self, index):
|
||||||
if self.latents_cache is None:
|
if index == 0:
|
||||||
self.latents_cache = {}
|
self.shuffle_buckets()
|
||||||
self.latents_cache[image_path] = latents
|
|
||||||
|
|
||||||
def __getitem__(self, index_arg):
|
bucket = self.buckets[self.buckets_indices[index][0]]
|
||||||
example = {}
|
image_index = self.buckets_indices[index][1] * self.batch_size
|
||||||
|
|
||||||
if not self.enable_reg_images:
|
latents_list = []
|
||||||
index = index_arg
|
images = []
|
||||||
img_path_captions = self.train_img_path_captions
|
captions = []
|
||||||
reg = False
|
loss_weights = []
|
||||||
else:
|
|
||||||
# 偶数ならtrain、奇数ならregを返す
|
|
||||||
if index_arg % 2 == 0:
|
|
||||||
img_path_captions = self.train_img_path_captions
|
|
||||||
reg = False
|
|
||||||
else:
|
|
||||||
img_path_captions = self.reg_img_path_captions
|
|
||||||
reg = True
|
|
||||||
index = index_arg // 2
|
|
||||||
example['loss_weight'] = 1.0 if (not reg or self.fine_tuning) else self.prior_loss_weight
|
|
||||||
|
|
||||||
index = index % len(img_path_captions)
|
for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]:
|
||||||
image_path, caption = img_path_captions[index]
|
loss_weights.append(1.0 if is_reg else self.prior_loss_weight)
|
||||||
example['image_path'] = image_path
|
|
||||||
|
|
||||||
# image/latentsを処理する
|
# image/latentsを処理する
|
||||||
if self.latents_cache is not None and image_path in self.latents_cache:
|
reso, latents = self.size_lat_cache[image_path]
|
||||||
# latentsはキャッシュ済み
|
|
||||||
example['latents'] = self.latents_cache[image_path]
|
if latents is None:
|
||||||
else:
|
|
||||||
# 画像を読み込み必要ならcropする
|
# 画像を読み込み必要ならcropする
|
||||||
img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
|
img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
|
||||||
im_h, im_w = img.shape[0:2]
|
im_h, im_w = img.shape[0:2]
|
||||||
|
|
||||||
|
if self.enable_bucket:
|
||||||
|
img = self.resize_and_trim(img, reso)
|
||||||
|
else:
|
||||||
if face_cx > 0: # 顔位置情報あり
|
if face_cx > 0: # 顔位置情報あり
|
||||||
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
||||||
elif im_h > self.height or im_w > self.width:
|
elif im_h > self.height or im_w > self.width:
|
||||||
@ -231,36 +373,47 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
img = img[:, p:p + self.width]
|
img = img[:, p:p + self.width]
|
||||||
|
|
||||||
im_h, im_w = img.shape[0:2]
|
im_h, im_w = img.shape[0:2]
|
||||||
assert im_h == self.height and im_w == self.width, f"image too small / 画像サイズが小さいようです: {image_path}"
|
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}"
|
||||||
|
|
||||||
# augmentation
|
# augmentation
|
||||||
if self.aug is not None:
|
if self.aug is not None:
|
||||||
img = self.aug(image=img)['image']
|
img = self.aug(image=img)['image']
|
||||||
|
|
||||||
example['image'] = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||||
|
else:
|
||||||
|
image = None
|
||||||
|
|
||||||
|
images.append(image)
|
||||||
|
latents_list.append(latents)
|
||||||
|
|
||||||
# captionを処理する
|
# captionを処理する
|
||||||
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする
|
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする
|
||||||
tokens = caption.strip().split(",")
|
tokens = caption.strip().split(",")
|
||||||
random.shuffle(tokens)
|
random.shuffle(tokens)
|
||||||
caption = ",".join(tokens).strip()
|
caption = ",".join(tokens).strip()
|
||||||
|
captions.append(caption)
|
||||||
|
|
||||||
input_ids = self.tokenizer(caption, padding="do_not_pad", truncation=True,
|
# input_idsをpadしてTensor変換
|
||||||
max_length=self.tokenizer.model_max_length).input_ids
|
|
||||||
|
|
||||||
# padしてTensor変換
|
|
||||||
if self.disable_padding:
|
if self.disable_padding:
|
||||||
# paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?)
|
# paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?)
|
||||||
input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||||
else:
|
else:
|
||||||
# paddingする
|
# paddingする
|
||||||
input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding='max_length', max_length=self.tokenizer.model_max_length,
|
input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
|
||||||
return_tensors='pt').input_ids
|
|
||||||
|
|
||||||
|
example = {}
|
||||||
|
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
||||||
example['input_ids'] = input_ids
|
example['input_ids'] = input_ids
|
||||||
|
if images[0] is not None:
|
||||||
|
images = torch.stack(images)
|
||||||
|
images = images.to(memory_format=torch.contiguous_format).float()
|
||||||
|
else:
|
||||||
|
images = None
|
||||||
|
example['images'] = images
|
||||||
|
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||||
if self.debug_dataset:
|
if self.debug_dataset:
|
||||||
example['caption'] = caption
|
example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]]
|
||||||
|
example['captions'] = captions
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
|
||||||
@ -916,7 +1069,7 @@ def load_models_from_stable_diffusion_checkpoint(ckpt_path):
|
|||||||
return text_model, vae, unet
|
return text_model, vae, unet
|
||||||
|
|
||||||
|
|
||||||
def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps):
|
def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
|
||||||
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
|
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
|
||||||
checkpoint = load_checkpoint_with_conversion(ckpt_path)
|
checkpoint = load_checkpoint_with_conversion(ckpt_path)
|
||||||
state_dict = checkpoint["state_dict"]
|
state_dict = checkpoint["state_dict"]
|
||||||
@ -926,6 +1079,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
|
|||||||
for k, v in unet_state_dict.items():
|
for k, v in unet_state_dict.items():
|
||||||
key = "model.diffusion_model." + k
|
key = "model.diffusion_model." + k
|
||||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||||
|
if save_dtype is not None:
|
||||||
|
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
state_dict[key] = v
|
state_dict[key] = v
|
||||||
|
|
||||||
# Convert the text encoder model
|
# Convert the text encoder model
|
||||||
@ -933,6 +1088,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
|
|||||||
for k, v in text_enc_dict.items():
|
for k, v in text_enc_dict.items():
|
||||||
key = "cond_stage_model.transformer." + k
|
key = "cond_stage_model.transformer." + k
|
||||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||||
|
if save_dtype is not None:
|
||||||
|
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
state_dict[key] = v
|
state_dict[key] = v
|
||||||
|
|
||||||
# Put together new checkpoint
|
# Put together new checkpoint
|
||||||
@ -951,24 +1108,7 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
|
|||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
input_ids = [e['input_ids'] for e in examples]
|
return examples[0]
|
||||||
input_ids = torch.stack(input_ids)
|
|
||||||
|
|
||||||
if 'latents' in examples[0]:
|
|
||||||
pixel_values = None
|
|
||||||
latents = [e['latents'] for e in examples]
|
|
||||||
latents = torch.stack(latents)
|
|
||||||
else:
|
|
||||||
pixel_values = [e['image'] for e in examples]
|
|
||||||
pixel_values = torch.stack(pixel_values)
|
|
||||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
|
||||||
latents = None
|
|
||||||
|
|
||||||
loss_weights = [e['loss_weight'] for e in examples]
|
|
||||||
loss_weights = torch.FloatTensor(loss_weights)
|
|
||||||
|
|
||||||
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "latents": latents, "loss_weights": loss_weights}
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@ -998,19 +1138,22 @@ def train(args):
|
|||||||
try:
|
try:
|
||||||
n_repeats = int(tokens[0])
|
n_repeats = int(tokens[0])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
|
# print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
|
||||||
raise e
|
# raise e
|
||||||
|
return 0, []
|
||||||
|
|
||||||
caption = '_'.join(tokens[1:])
|
caption = '_'.join(tokens[1:])
|
||||||
|
|
||||||
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg"))
|
print(f"found directory {n_repeats}_{caption}")
|
||||||
|
|
||||||
|
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp"))
|
||||||
return n_repeats, [(ip, caption) for ip in img_paths]
|
return n_repeats, [(ip, caption) for ip in img_paths]
|
||||||
|
|
||||||
print("prepare train images.")
|
print("prepare train images.")
|
||||||
train_img_path_captions = []
|
train_img_path_captions = []
|
||||||
|
|
||||||
if fine_tuning:
|
if fine_tuning:
|
||||||
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg"))
|
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
for img_path in tqdm(img_paths):
|
for img_path in tqdm(img_paths):
|
||||||
# captionの候補ファイル名を作る
|
# captionの候補ファイル名を作る
|
||||||
base_name = os.path.splitext(img_path)[0]
|
base_name = os.path.splitext(img_path)[0]
|
||||||
@ -1042,7 +1185,7 @@ def train(args):
|
|||||||
n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
|
n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
|
||||||
for _ in range(n_repeats):
|
for _ in range(n_repeats):
|
||||||
train_img_path_captions.extend(img_caps)
|
train_img_path_captions.extend(img_caps)
|
||||||
print(f"{len(train_img_path_captions)} train images.")
|
print(f"{len(train_img_path_captions)} train images with repeating.")
|
||||||
|
|
||||||
reg_img_path_captions = []
|
reg_img_path_captions = []
|
||||||
if args.reg_data_dir:
|
if args.reg_data_dir:
|
||||||
@ -1054,11 +1197,6 @@ def train(args):
|
|||||||
reg_img_path_captions.extend(img_caps)
|
reg_img_path_captions.extend(img_caps)
|
||||||
print(f"{len(reg_img_path_captions)} reg images.")
|
print(f"{len(reg_img_path_captions)} reg images.")
|
||||||
|
|
||||||
if args.debug_dataset:
|
|
||||||
# デバッグ時はshuffleして実際のデータセット使用時に近づける(学習時はdata loaderでshuffleする)
|
|
||||||
random.shuffle(train_img_path_captions)
|
|
||||||
random.shuffle(reg_img_path_captions)
|
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
resolution = tuple([int(r) for r in args.resolution.split(',')])
|
resolution = tuple([int(r) for r in args.resolution.split(',')])
|
||||||
if len(resolution) == 1:
|
if len(resolution) == 1:
|
||||||
@ -1078,29 +1216,40 @@ def train(args):
|
|||||||
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
||||||
|
|
||||||
print("prepare dataset")
|
print("prepare dataset")
|
||||||
train_dataset = DreamBoothOrFineTuningDataset(fine_tuning, train_img_path_captions,
|
train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution,
|
||||||
reg_img_path_captions, tokenizer, resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.shuffle_caption, args.no_token_padding, args.debug_dataset)
|
args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop,
|
||||||
|
args.shuffle_caption, args.no_token_padding, args.debug_dataset)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
|
train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る
|
||||||
|
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||||
print("Escape for exit. / Escキーで中断、終了します")
|
print("Escape for exit. / Escキーで中断、終了します")
|
||||||
for example in train_dataset:
|
for example in train_dataset:
|
||||||
im = example['image']
|
for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']):
|
||||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||||
print(f'caption: "{example["caption"]}", loss weight: {example["loss_weight"]}')
|
print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}')
|
||||||
cv2.imshow("img", im)
|
cv2.imshow("img", im)
|
||||||
k = cv2.waitKey()
|
k = cv2.waitKey()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
if k == 27:
|
if k == 27:
|
||||||
break
|
break
|
||||||
|
if k == 27:
|
||||||
|
break
|
||||||
return
|
return
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
# gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする
|
# gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする
|
||||||
print("prepare accelerator")
|
print("prepare accelerator")
|
||||||
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision)
|
if args.logging_dir is None:
|
||||||
|
log_with = None
|
||||||
|
logging_dir = None
|
||||||
|
else:
|
||||||
|
log_with = "tensorboard"
|
||||||
|
logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
|
||||||
|
log_with=log_with, logging_dir=logging_dir)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
@ -1122,28 +1271,24 @@ def train(args):
|
|||||||
elif args.mixed_precision == "bf16":
|
elif args.mixed_precision == "bf16":
|
||||||
weight_dtype = torch.bfloat16
|
weight_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
save_dtype = None
|
||||||
|
if args.save_precision == "fp16":
|
||||||
|
save_dtype = torch.float16
|
||||||
|
elif args.save_precision == "bf16":
|
||||||
|
save_dtype = torch.bfloat16
|
||||||
|
elif args.save_precision == "float":
|
||||||
|
save_dtype = torch.float32
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
# latentをcacheする→新しいDatasetを作るとcaptionのshuffleが効かないので元のDatasetにcacheを持つ(cascadeする手もあるが)
|
|
||||||
print("caching latents.")
|
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
for i in tqdm(range(len(train_dataset))):
|
|
||||||
example = train_dataset[i]
|
|
||||||
if 'latents' not in example:
|
|
||||||
image_path = example['image_path']
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pixel_values = example["image"].unsqueeze(0).to(device=accelerator.device, dtype=weight_dtype)
|
train_dataset.make_buckets_with_caching(args.enable_bucket, vae)
|
||||||
latents = vae.encode(pixel_values).latent_dist.sample().squeeze(0).to("cpu")
|
|
||||||
train_dataset.set_cached_latents(image_path, latents)
|
|
||||||
# assertion
|
|
||||||
for i in range(len(train_dataset)):
|
|
||||||
assert 'latents' in train_dataset[i], "internal error: latents not cached"
|
|
||||||
|
|
||||||
del vae
|
del vae
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
|
train_dataset.make_buckets_with_caching(args.enable_bucket, None)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
@ -1173,7 +1318,7 @@ def train(args):
|
|||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
|
lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
|
||||||
@ -1185,6 +1330,11 @@ def train(args):
|
|||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# resumeする
|
||||||
|
if args.resume is not None:
|
||||||
|
print(f"resume training from state: {args.resume}")
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
|
||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
|
num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
|
||||||
|
|
||||||
@ -1193,7 +1343,7 @@ def train(args):
|
|||||||
print("running training / 学習開始")
|
print("running training / 学習開始")
|
||||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
||||||
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
||||||
print(f" num examples / サンプル数: {len(train_dataset)}")
|
print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}")
|
||||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
@ -1222,7 +1372,7 @@ def train(args):
|
|||||||
if cache_latents:
|
if cache_latents:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
else:
|
else:
|
||||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
@ -1271,15 +1421,22 @@ def train(args):
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
loss_total += current_loss
|
loss_total += current_loss
|
||||||
avr_loss = loss_total / (step+1)
|
avr_loss = loss_total / (step+1)
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
# accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
||||||
|
accelerator.log(logs, step=epoch+1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if use_stable_diffusion_format and args.save_every_n_epochs is not None:
|
if use_stable_diffusion_format and args.save_every_n_epochs is not None:
|
||||||
@ -1288,7 +1445,11 @@ def train(args):
|
|||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
|
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
|
||||||
save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
||||||
args.pretrained_model_name_or_path, epoch + 1, global_step)
|
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
|
||||||
|
|
||||||
|
if args.save_state:
|
||||||
|
print("saving state.")
|
||||||
|
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
@ -1296,6 +1457,11 @@ def train(args):
|
|||||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
if args.save_state:
|
||||||
|
print("saving last state.")
|
||||||
|
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
|
||||||
|
|
||||||
del accelerator # この後メモリを使うのでこれは消す
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
@ -1303,7 +1469,8 @@ def train(args):
|
|||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
||||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||||
save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step)
|
save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet,
|
||||||
|
args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
|
||||||
else:
|
else:
|
||||||
# Create the pipeline using using the trained modules and save it.
|
# Create the pipeline using using the trained modules and save it.
|
||||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||||
@ -1589,6 +1756,10 @@ if __name__ == '__main__':
|
|||||||
help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
|
help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
|
||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
||||||
|
parser.add_argument("--save_state", action="store_true",
|
||||||
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
|
parser.add_argument("--resume", type=str, default=None,
|
||||||
|
help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
|
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
|
||||||
parser.add_argument("--no_token_padding", action="store_true",
|
parser.add_argument("--no_token_padding", action="store_true",
|
||||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
||||||
@ -1612,6 +1783,8 @@ if __name__ == '__main__':
|
|||||||
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||||
parser.add_argument("--cache_latents", action="store_true",
|
parser.add_argument("--cache_latents", action="store_true",
|
||||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
||||||
|
parser.add_argument("--enable_bucket", action="store_true",
|
||||||
|
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
||||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||||
@ -1619,8 +1792,12 @@ if __name__ == '__main__':
|
|||||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||||
|
parser.add_argument("--save_precision", type=str, default=None,
|
||||||
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
||||||
parser.add_argument("--clip_skip", type=int, default=None,
|
parser.add_argument("--clip_skip", type=int, default=None,
|
||||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||||
|
parser.add_argument("--logging_dir", type=str, default=None,
|
||||||
|
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
Loading…
Reference in New Issue
Block a user