Update train_db_fixed to v9

This commit is contained in:
Bernard Maltais 2022-11-19 08:49:42 -05:00
parent f56340d53e
commit 0e8b993def
7 changed files with 593 additions and 146 deletions

View File

@ -113,9 +113,10 @@ accelerate launch --num_cpu_threads_per_process 6 train_db_fixed-ber.py `
--cache_latents ` --cache_latents `
--save_every_n_epochs=1 ` --save_every_n_epochs=1 `
--fine_tuning ` --fine_tuning `
--enable_bucket `
--dataset_repeats=200 ` --dataset_repeats=200 `
--seed=23 ` --seed=23 `
--save_half ---save_precision="fp16"
``` ```
Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e Refer to this url for more details about finetuning: https://note.com/kohya_ss/n/n1269f1e1a54e
@ -125,7 +126,12 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
* 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short). * 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short).
* 11/9 (v8): supports Diffusers 0.7.2. To upgrade diffusers run `pip install --upgrade diffusers[torch]` * 11/9 (v8): supports Diffusers 0.7.2. To upgrade diffusers run `pip install --upgrade diffusers[torch]`
* 11/14 (diffusers_fine_tuning v2): * 11/14 (diffusers_fine_tuning v2):
- script name is now fine_tune.py. - script name is now fine_tune.py.
- Added option to learn Text Encoder --train_text_encoder. - Added option to learn Text Encoder --train_text_encoder.
- The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16. - The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16.
- Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option. - Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option.
* 11/18 (v9):
- Added support for Aspect Ratio Bucketing (enable_bucket option). (--enable_bucket)
- Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision)
- Added support for saving learning state (--save_state, --resume)
- Added support for logging (--logging_dir)

View File

@ -2,9 +2,12 @@
# #
# Usefull to create base caption that will be augmented on a per image basis # Usefull to create base caption that will be augmented on a per image basis
$folder = "D:\dreambooth\train_sylvia_ritter\raw_data\all-images\" $folder = "D:\some\folder\location\"
$file_pattern="*.*" $file_pattern="*.*"
$text_fir_file="a digital painting of xxx, by silvery trait" $caption_text="some caption text"
$files = Get-ChildItem $folder$file_pattern $files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File
foreach ($file in $files) {New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file} foreach ($file in $files)
{
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
}

View File

@ -0,0 +1,20 @@
# This powershell script will create a text file for each files in the folder
#
# Usefull to create base caption that will be augmented on a per image basis
$folder = "D:\test\t2\"
$file_pattern="*.*"
$text_fir_file="bigeyes style"
foreach ($file in Get-ChildItem $folder\$file_pattern -File)
{
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $text_fir_file
}
foreach($directory in Get-ChildItem -path $folder -Directory)
{
foreach ($file in Get-ChildItem $folder\$directory\$file_pattern)
{
New-Item -ItemType file -Path $folder\$directory -Name "$($file.BaseName).txt" -Value $text_fir_file
}
}

View File

@ -0,0 +1,87 @@
# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape,
# portrait and square images.
#
# Adjust the script to your own needs
# Sylvia Ritter
# variable values
$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt"
$data_dir = "D:\test\squat"
$train_dir = "D:\test\"
$resolution = "512,512"
$image_num = Get-ChildItem $data_dir -Recurse -File -Include *.png | Measure-Object | %{$_.Count}
Write-Output "image_num: $image_num"
$learning_rate = 1e-6
$dataset_repeats = 40
$train_batch_size = 8
$epoch = 1
$save_every_n_epochs=1
$mixed_precision="fp16"
$num_cpu_threads_per_process=6
# You should not have to change values past this point
$output_dir = $train_dir + "\model"
$repeats = $image_num * $dataset_repeats
$mts = [Math]::Ceiling($repeats / $train_batch_size * $epoch)
Write-Output "Repeats: $repeats"
.\venv\Scripts\activate
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py `
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
--train_data_dir=$data_dir `
--output_dir=$output_dir `
--resolution=$resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$mts `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$dataset_repeats `
--save_precision="fp16"
# 2nd pass at half the dataset repeat value
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--pretrained_model_name_or_path=$output_dir"\last.ckpt" `
--train_data_dir=$data_dir `
--output_dir=$output_dir"2" `
--resolution=$resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$([Math]::Ceiling($mts/2)) `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
--save_precision="fp16"
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed-ber.py `
--pretrained_model_name_or_path=$output_dir"\last.ckpt" `
--train_data_dir=$data_dir `
--output_dir=$output_dir"2" `
--resolution=$resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$mts `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$dataset_repeats `
--save_precision="fp16"

View File

@ -0,0 +1,154 @@
# This powershell script will create a model using the fine tuning dreambooth method. It will require landscape,
# portrait and square images.
#
# Adjust the script to your own needs
# Sylvia Ritter
# variable values
$pretrained_model_name_or_path = "D:\models\v1-5-pruned-mse-vae.ckpt"
$train_dir = "D:\dreambooth\train_sylvia_ritter\raw_data"
$landscape_image_num = 4
$portrait_image_num = 25
$square_image_num = 2
$learning_rate = 1e-6
$dataset_repeats = 120
$train_batch_size = 4
$epoch = 1
$save_every_n_epochs=1
$mixed_precision="fp16"
$num_cpu_threads_per_process=6
$landscape_folder_name = "landscape-pp"
$landscape_resolution = "832,512"
$portrait_folder_name = "portrait-pp"
$portrait_resolution = "448,896"
$square_folder_name = "square-pp"
$square_resolution = "512,512"
# You should not have to change values past this point
$landscape_data_dir = $train_dir + "\" + $landscape_folder_name
$portrait_data_dir = $train_dir + "\" + $portrait_folder_name
$square_data_dir = $train_dir + "\" + $square_folder_name
$landscape_output_dir = $train_dir + "\model-l"
$portrait_output_dir = $train_dir + "\model-lp"
$square_output_dir = $train_dir + "\model-lps"
$landscape_repeats = $landscape_image_num * $dataset_repeats
$portrait_repeats = $portrait_image_num * $dataset_repeats
$square_repeats = $square_image_num * $dataset_repeats
$landscape_mts = [Math]::Ceiling($landscape_repeats / $train_batch_size * $epoch)
$portrait_mts = [Math]::Ceiling($portrait_repeats / $train_batch_size * $epoch)
$square_mts = [Math]::Ceiling($square_repeats / $train_batch_size * $epoch)
# Write-Output $landscape_repeats
.\venv\Scripts\activate
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--pretrained_model_name_or_path=$pretrained_model_name_or_path `
--train_data_dir=$landscape_data_dir `
--output_dir=$landscape_output_dir `
--resolution=$landscape_resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$landscape_mts `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$dataset_repeats `
--save_precision="fp16"
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--pretrained_model_name_or_path=$landscape_output_dir"\last.ckpt" `
--train_data_dir=$portrait_data_dir `
--output_dir=$portrait_output_dir `
--resolution=$portrait_resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$portrait_mts `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$dataset_repeats `
--save_precision="fp16"
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--pretrained_model_name_or_path=$portrait_output_dir"\last.ckpt" `
--train_data_dir=$square_data_dir `
--output_dir=$square_output_dir `
--resolution=$square_resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$square_mts `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$dataset_repeats `
--save_precision="fp16"
# 2nd pass at half the dataset repeat value
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--pretrained_model_name_or_path=$square_output_dir"\last.ckpt" `
--train_data_dir=$landscape_data_dir `
--output_dir=$landscape_output_dir"2" `
--resolution=$landscape_resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$([Math]::Ceiling($landscape_mts/2)) `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
--save_precision="fp16"
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--pretrained_model_name_or_path=$landscape_output_dir"2\last.ckpt" `
--train_data_dir=$portrait_data_dir `
--output_dir=$portrait_output_dir"2" `
--resolution=$portrait_resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$([Math]::Ceiling($portrait_mts/2)) `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
--save_precision="fp16"
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process train_db_fixed.py `
--pretrained_model_name_or_path=$portrait_output_dir"2\last.ckpt" `
--train_data_dir=$square_data_dir `
--output_dir=$square_output_dir"2" `
--resolution=$square_resolution `
--train_batch_size=$train_batch_size `
--learning_rate=$learning_rate `
--max_train_steps=$([Math]::Ceiling($square_mts/2)) `
--use_8bit_adam `
--xformers `
--mixed_precision=$mixed_precision `
--cache_latents `
--save_every_n_epochs=$save_every_n_epochs `
--fine_tuning `
--dataset_repeats=$([Math]::Ceiling($dataset_repeats/2)) `
--save_precision="fp16"

View File

@ -55,7 +55,7 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\
--use_8bit_adam --xformers ` --use_8bit_adam --xformers `
--mixed_precision=$mixed_precision ` --mixed_precision=$mixed_precision `
--save_every_n_epochs=$save_every_n_epochs ` --save_every_n_epochs=$save_every_n_epochs `
--save_half --save_precision="fp16"
accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\kohya_ss\diffusers_fine_tuning\fine_tune.py ` accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\kohya_ss\diffusers_fine_tuning\fine_tune.py `
--pretrained_model_name_or_path=$train_dir"\fine_tuned\last.ckpt" ` --pretrained_model_name_or_path=$train_dir"\fine_tuned\last.ckpt" `
@ -69,4 +69,4 @@ accelerate launch --num_cpu_threads_per_process $num_cpu_threads_per_process D:\
--use_8bit_adam --xformers ` --use_8bit_adam --xformers `
--mixed_precision=$mixed_precision ` --mixed_precision=$mixed_precision `
--save_every_n_epochs=$save_every_n_epochs ` --save_every_n_epochs=$save_every_n_epochs `
--save_half --save_precision="fp16"

View File

@ -4,7 +4,9 @@
# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images, # v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
# enable reg images in fine-tuning, add dataset_repeats option # enable reg images in fine-tuning, add dataset_repeats option
# v8: supports Diffusers 0.7.2 # v8: supports Diffusers 0.7.2
# v9: add bucketing option
import time
from torch.autograd.function import Function from torch.autograd.function import Function
import argparse import argparse
import glob import glob
@ -56,13 +58,40 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
# checkpointファイル名 # checkpointファイル名
LAST_CHECKPOINT_NAME = "last.ckpt" LAST_CHECKPOINT_NAME = "last.ckpt"
LAST_STATE_NAME = "last-state"
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt" EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
EPOCH_STATE_NAME = "epoch-{:06d}-state"
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
max_width, max_height = max_reso
max_area = (max_width // divisible) * (max_height // divisible)
resos = set()
size = int(math.sqrt(max_area)) * divisible
resos.add((size, size))
size = min_size
while size <= max_size:
width = size
height = min(max_size, (max_area // (width // divisible)) * divisible)
resos.add((width, height))
resos.add((height, width))
size += divisible
resos = list(resos)
resos.sort()
aspect_ratios = [w / h for w, h in resos]
return resos, aspect_ratios
class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None: def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
super().__init__() super().__init__()
self.batch_size = batch_size
self.fine_tuning = fine_tuning self.fine_tuning = fine_tuning
self.train_img_path_captions = train_img_path_captions self.train_img_path_captions = train_img_path_captions
self.reg_img_path_captions = reg_img_path_captions self.reg_img_path_captions = reg_img_path_captions
@ -76,6 +105,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
self.shuffle_caption = shuffle_caption self.shuffle_caption = shuffle_caption
self.disable_padding = disable_padding self.disable_padding = disable_padding
self.latents_cache = None self.latents_cache = None
self.enable_bucket = False
# augmentation # augmentation
flip_p = 0.5 if flip_aug else 0.0 flip_p = 0.5 if flip_aug else 0.0
@ -102,12 +132,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
self.enable_reg_images = self.num_reg_images > 0 self.enable_reg_images = self.num_reg_images > 0
if not self.enable_reg_images: if self.enable_reg_images and self.num_train_images < self.num_reg_images:
self._length = self.num_train_images
else:
# 学習データの倍として、奇数ならtrain
self._length = self.num_train_images * 2
if self._length // 2 < self.num_reg_images:
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
self.image_transforms = transforms.Compose( self.image_transforms = transforms.Compose(
@ -117,6 +142,132 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
] ]
) )
# bucketingを行わない場合も呼び出し必須ひとつだけbucketを作る
def make_buckets_with_caching(self, enable_bucket, vae):
self.enable_bucket = enable_bucket
cache_latents = vae is not None
if cache_latents:
if enable_bucket:
print("cache latents with bucketing")
else:
print("cache latents")
else:
if enable_bucket:
print("make buckets")
else:
print("prepare dataset")
# bucketingを用意する
if enable_bucket:
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height))
else:
# bucketはひとつだけ、すべての画像は同じ解像度
bucket_resos = [(self.width, self.height)]
bucket_aspect_ratios = [self.width / self.height]
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
# 画像の解像度、latentをあらかじめ取得する
img_ar_errors = []
self.size_lat_cache = {}
for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions):
if image_path in self.size_lat_cache:
continue
image = self.load_image(image_path)[0]
image_height, image_width = image.shape[0:2]
if not enable_bucket:
# assert image_width == self.width and image_height == self.height, \
# f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}"
reso = (self.width, self.height)
else:
# bucketを決める
aspect_ratio = image_width / image_height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
reso = bucket_resos[bucket_id]
ar_error = ar_errors[bucket_id]
img_ar_errors.append(ar_error)
if cache_latents:
image = self.resize_and_trim(image, reso)
# latentを取得する
if cache_latents:
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
else:
latents = None
self.size_lat_cache[image_path] = (reso, latents)
# 画像をbucketに分割する
self.buckets = [[] for _ in range(len(bucket_resos))]
reso_to_index = {}
for i, reso in enumerate(bucket_resos):
reso_to_index[reso] = i
def split_to_buckets(is_reg, img_path_captions):
for image_path, caption in img_path_captions:
reso, _ = self.size_lat_cache[image_path]
bucket_index = reso_to_index[reso]
self.buckets[bucket_index].append((is_reg, image_path, caption))
split_to_buckets(False, self.train_img_path_captions)
if self.enable_reg_images:
l = []
while len(l) < len(self.train_img_path_captions):
l += self.reg_img_path_captions
l = l[:len(self.train_img_path_captions)]
split_to_buckets(True, l)
if enable_bucket:
print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数")
for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)):
print(f"bucket {i}: resolution {reso}, count: {len(imgs)}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}")
# 参照用indexを作る
self.buckets_indices = []
for bucket_index, bucket in enumerate(self.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
self.buckets_indices.append((bucket_index, batch_index))
self.shuffle_buckets()
self._length = len(self.buckets_indices)
# どのサイズにリサイズするか→トリミングする方向で
def resize_and_trim(self, image, reso):
image_height, image_width = image.shape[0:2]
ar_img = image_width / image_height
ar_reso = reso[0] / reso[1]
if ar_img > ar_reso: # 横が長い→縦を合わせる
scale = reso[1] / image_height
else:
scale = reso[0] / image_width
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]]
elif resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
f"internal error, illegal trimmed size: {image.shape}, {reso}"
return image
def shuffle_buckets(self):
random.shuffle(self.buckets_indices)
for bucket in self.buckets:
random.shuffle(bucket)
def load_image(self, image_path): def load_image(self, image_path):
image = Image.open(image_path) image = Image.open(image_path)
if not image.mode == "RGB": if not image.mode == "RGB":
@ -184,41 +335,32 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
return self._length return self._length
def set_cached_latents(self, image_path, latents): def __getitem__(self, index):
if self.latents_cache is None: if index == 0:
self.latents_cache = {} self.shuffle_buckets()
self.latents_cache[image_path] = latents
def __getitem__(self, index_arg): bucket = self.buckets[self.buckets_indices[index][0]]
example = {} image_index = self.buckets_indices[index][1] * self.batch_size
if not self.enable_reg_images: latents_list = []
index = index_arg images = []
img_path_captions = self.train_img_path_captions captions = []
reg = False loss_weights = []
else:
# 偶数ならtrain、奇数ならregを返す
if index_arg % 2 == 0:
img_path_captions = self.train_img_path_captions
reg = False
else:
img_path_captions = self.reg_img_path_captions
reg = True
index = index_arg // 2
example['loss_weight'] = 1.0 if (not reg or self.fine_tuning) else self.prior_loss_weight
index = index % len(img_path_captions) for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]:
image_path, caption = img_path_captions[index] loss_weights.append(1.0 if is_reg else self.prior_loss_weight)
example['image_path'] = image_path
# image/latentsを処理する # image/latentsを処理する
if self.latents_cache is not None and image_path in self.latents_cache: reso, latents = self.size_lat_cache[image_path]
# latentsはキャッシュ済み
example['latents'] = self.latents_cache[image_path] if latents is None:
else:
# 画像を読み込み必要ならcropする # 画像を読み込み必要ならcropする
img, face_cx, face_cy, face_w, face_h = self.load_image(image_path) img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
im_h, im_w = img.shape[0:2] im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img = self.resize_and_trim(img, reso)
else:
if face_cx > 0: # 顔位置情報あり if face_cx > 0: # 顔位置情報あり
img = self.crop_target(img, face_cx, face_cy, face_w, face_h) img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
elif im_h > self.height or im_w > self.width: elif im_h > self.height or im_w > self.width:
@ -231,36 +373,47 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
img = img[:, p:p + self.width] img = img[:, p:p + self.width]
im_h, im_w = img.shape[0:2] im_h, im_w = img.shape[0:2]
assert im_h == self.height and im_w == self.width, f"image too small / 画像サイズが小さいようです: {image_path}" assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}"
# augmentation # augmentation
if self.aug is not None: if self.aug is not None:
img = self.aug(image=img)['image'] img = self.aug(image=img)['image']
example['image'] = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
else:
image = None
images.append(image)
latents_list.append(latents)
# captionを処理する # captionを処理する
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする
tokens = caption.strip().split(",") tokens = caption.strip().split(",")
random.shuffle(tokens) random.shuffle(tokens)
caption = ",".join(tokens).strip() caption = ",".join(tokens).strip()
captions.append(caption)
input_ids = self.tokenizer(caption, padding="do_not_pad", truncation=True, # input_idsをpadしてTensor変換
max_length=self.tokenizer.model_max_length).input_ids
# padしてTensor変換
if self.disable_padding: if self.disable_padding:
# paddingしないpadding==Trueはバッチの中の最大長に合わせるだけやはりバグでは…… # paddingしないpadding==Trueはバッチの中の最大長に合わせるだけやはりバグでは……
input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
else: else:
# paddingする # paddingする
input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding='max_length', max_length=self.tokenizer.model_max_length, input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
return_tensors='pt').input_ids
example = {}
example['loss_weights'] = torch.FloatTensor(loss_weights)
example['input_ids'] = input_ids example['input_ids'] = input_ids
if images[0] is not None:
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
else:
images = None
example['images'] = images
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
if self.debug_dataset: if self.debug_dataset:
example['caption'] = caption example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]]
example['captions'] = captions
return example return example
@ -916,7 +1069,7 @@ def load_models_from_stable_diffusion_checkpoint(ckpt_path):
return text_model, vae, unet return text_model, vae, unet
def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps): def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む # VAEがメモリ上にないので、もう一度VAEを含めて読み込む
checkpoint = load_checkpoint_with_conversion(ckpt_path) checkpoint = load_checkpoint_with_conversion(ckpt_path)
state_dict = checkpoint["state_dict"] state_dict = checkpoint["state_dict"]
@ -926,6 +1079,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
for k, v in unet_state_dict.items(): for k, v in unet_state_dict.items():
key = "model.diffusion_model." + k key = "model.diffusion_model." + k
assert key in state_dict, f"Illegal key in save SD: {key}" assert key in state_dict, f"Illegal key in save SD: {key}"
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v state_dict[key] = v
# Convert the text encoder model # Convert the text encoder model
@ -933,6 +1088,8 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
for k, v in text_enc_dict.items(): for k, v in text_enc_dict.items():
key = "cond_stage_model.transformer." + k key = "cond_stage_model.transformer." + k
assert key in state_dict, f"Illegal key in save SD: {key}" assert key in state_dict, f"Illegal key in save SD: {key}"
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v state_dict[key] = v
# Put together new checkpoint # Put together new checkpoint
@ -951,24 +1108,7 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
def collate_fn(examples): def collate_fn(examples):
input_ids = [e['input_ids'] for e in examples] return examples[0]
input_ids = torch.stack(input_ids)
if 'latents' in examples[0]:
pixel_values = None
latents = [e['latents'] for e in examples]
latents = torch.stack(latents)
else:
pixel_values = [e['image'] for e in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
latents = None
loss_weights = [e['loss_weight'] for e in examples]
loss_weights = torch.FloatTensor(loss_weights)
batch = {"input_ids": input_ids, "pixel_values": pixel_values, "latents": latents, "loss_weights": loss_weights}
return batch
def train(args): def train(args):
@ -998,19 +1138,22 @@ def train(args):
try: try:
n_repeats = int(tokens[0]) n_repeats = int(tokens[0])
except ValueError as e: except ValueError as e:
print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}") # print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
raise e # raise e
return 0, []
caption = '_'.join(tokens[1:]) caption = '_'.join(tokens[1:])
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) print(f"found directory {n_repeats}_{caption}")
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp"))
return n_repeats, [(ip, caption) for ip in img_paths] return n_repeats, [(ip, caption) for ip in img_paths]
print("prepare train images.") print("prepare train images.")
train_img_path_captions = [] train_img_path_captions = []
if fine_tuning: if fine_tuning:
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
for img_path in tqdm(img_paths): for img_path in tqdm(img_paths):
# captionの候補ファイル名を作る # captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0] base_name = os.path.splitext(img_path)[0]
@ -1042,7 +1185,7 @@ def train(args):
n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir)) n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
for _ in range(n_repeats): for _ in range(n_repeats):
train_img_path_captions.extend(img_caps) train_img_path_captions.extend(img_caps)
print(f"{len(train_img_path_captions)} train images.") print(f"{len(train_img_path_captions)} train images with repeating.")
reg_img_path_captions = [] reg_img_path_captions = []
if args.reg_data_dir: if args.reg_data_dir:
@ -1054,11 +1197,6 @@ def train(args):
reg_img_path_captions.extend(img_caps) reg_img_path_captions.extend(img_caps)
print(f"{len(reg_img_path_captions)} reg images.") print(f"{len(reg_img_path_captions)} reg images.")
if args.debug_dataset:
# デバッグ時はshuffleして実際のデータセット使用時に近づける学習時はdata loaderでshuffleする
random.shuffle(train_img_path_captions)
random.shuffle(reg_img_path_captions)
# データセットを準備する # データセットを準備する
resolution = tuple([int(r) for r in args.resolution.split(',')]) resolution = tuple([int(r) for r in args.resolution.split(',')])
if len(resolution) == 1: if len(resolution) == 1:
@ -1078,29 +1216,40 @@ def train(args):
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
print("prepare dataset") print("prepare dataset")
train_dataset = DreamBoothOrFineTuningDataset(fine_tuning, train_img_path_captions, train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution,
reg_img_path_captions, tokenizer, resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.shuffle_caption, args.no_token_padding, args.debug_dataset) args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop,
args.shuffle_caption, args.no_token_padding, args.debug_dataset)
if args.debug_dataset: if args.debug_dataset:
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します") print("Escape for exit. / Escキーで中断、終了します")
for example in train_dataset: for example in train_dataset:
im = example['image'] for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']):
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV) im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
print(f'caption: "{example["caption"]}", loss weight: {example["loss_weight"]}') print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}')
cv2.imshow("img", im) cv2.imshow("img", im)
k = cv2.waitKey() k = cv2.waitKey()
cv2.destroyAllWindows() cv2.destroyAllWindows()
if k == 27: if k == 27:
break break
if k == 27:
break
return return
# acceleratorを準備する # acceleratorを準備する
# gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする # gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする
print("prepare accelerator") print("prepare accelerator")
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision) if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
log_with=log_with, logging_dir=logging_dir)
# モデルを読み込む # モデルを読み込む
if use_stable_diffusion_format: if use_stable_diffusion_format:
@ -1122,28 +1271,24 @@ def train(args):
elif args.mixed_precision == "bf16": elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16 weight_dtype = torch.bfloat16
save_dtype = None
if args.save_precision == "fp16":
save_dtype = torch.float16
elif args.save_precision == "bf16":
save_dtype = torch.bfloat16
elif args.save_precision == "float":
save_dtype = torch.float32
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
# latentをcacheする→新しいDatasetを作るとcaptionのshuffleが効かないので元のDatasetにcacheを持つcascadeする手もあるが
print("caching latents.")
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
for i in tqdm(range(len(train_dataset))):
example = train_dataset[i]
if 'latents' not in example:
image_path = example['image_path']
with torch.no_grad(): with torch.no_grad():
pixel_values = example["image"].unsqueeze(0).to(device=accelerator.device, dtype=weight_dtype) train_dataset.make_buckets_with_caching(args.enable_bucket, vae)
latents = vae.encode(pixel_values).latent_dist.sample().squeeze(0).to("cpu")
train_dataset.set_cached_latents(image_path, latents)
# assertion
for i in range(len(train_dataset)):
assert 'latents' in train_dataset[i], "internal error: latents not cached"
del vae del vae
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
else: else:
train_dataset.make_buckets_with_caching(args.enable_bucket, None)
vae.requires_grad_(False) vae.requires_grad_(False)
if args.gradient_checkpointing: if args.gradient_checkpointing:
@ -1173,7 +1318,7 @@ def train(args):
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=n_workers) train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps) lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
@ -1185,6 +1330,11 @@ def train(args):
if not cache_latents: if not cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する # epoch数を計算する
num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader)) num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
@ -1193,7 +1343,7 @@ def train(args):
print("running training / 学習開始") print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
print(f" num examples / サンプル数: {len(train_dataset)}") print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}") print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@ -1222,7 +1372,7 @@ def train(args):
if cache_latents: if cache_latents:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)
else: else:
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215 latents = latents * 0.18215
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
@ -1271,15 +1421,22 @@ def train(args):
global_step += 1 global_step += 1
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
loss_total += current_loss loss_total += current_loss
avr_loss = loss_total / (step+1) avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
# accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if use_stable_diffusion_format and args.save_every_n_epochs is not None: if use_stable_diffusion_format and args.save_every_n_epochs is not None:
@ -1288,7 +1445,11 @@ def train(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1)) ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet), save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
args.pretrained_model_name_or_path, epoch + 1, global_step) args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
@ -1296,6 +1457,11 @@ def train(args):
text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training() accelerator.end_training()
if args.save_state:
print("saving last state.")
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
del accelerator # この後メモリを使うのでこれは消す del accelerator # この後メモリを使うのでこれは消す
if is_main_process: if is_main_process:
@ -1303,7 +1469,8 @@ def train(args):
if use_stable_diffusion_format: if use_stable_diffusion_format:
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME) ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step) save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet,
args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
else: else:
# Create the pipeline using using the trained modules and save it. # Create the pipeline using using the trained modules and save it.
print(f"save trained model as Diffusers to {args.output_dir}") print(f"save trained model as Diffusers to {args.output_dir}")
@ -1589,6 +1756,10 @@ if __name__ == '__main__':
help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)") help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
parser.add_argument("--save_every_n_epochs", type=int, default=None, parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存しますStableDiffusion形式のモデルを読み込んだ場合のみ有効") help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存しますStableDiffusion形式のモデルを読み込んだ場合のみ有効")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None,
help="saved state to resume training / 学習再開するモデルのstate")
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
parser.add_argument("--no_token_padding", action="store_true", parser.add_argument("--no_token_padding", action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作") help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作")
@ -1612,6 +1783,8 @@ if __name__ == '__main__':
help="use xformers for CrossAttention / CrossAttentionにxformersを使う") help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument("--cache_latents", action="store_true", parser.add_argument("--cache_latents", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可") help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可")
parser.add_argument("--enable_bucket", action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
@ -1619,8 +1792,12 @@ if __name__ == '__main__':
help="enable gradient checkpointing / grandient checkpointingを有効にする") help="enable gradient checkpointing / grandient checkpointingを有効にする")
parser.add_argument("--mixed_precision", type=str, default="no", parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
parser.add_argument("--clip_skip", type=int, default=None, parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上") help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)