Update to v10

This commit is contained in:
Bernard Maltais 2022-11-21 07:50:04 -05:00
parent a6c7bb06dc
commit 2629617de7
3 changed files with 73 additions and 40 deletions

View File

@ -135,3 +135,8 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
- Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision) - Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision)
- Added support for saving learning state (--save_state, --resume) - Added support for saving learning state (--save_state, --resume)
- Added support for logging (--logging_dir) - Added support for logging (--logging_dir)
* 11/21 (v10):
- Added minimum/maximum resolution specification when using Aspect Ratio Bucketing (min_bucket_reso/max_bucket_reso option).
- Added extension specification for caption files (caption_extention).
- Added support for images with .webp extension.
- Added a function that allows captions to learning images and regularized images.

View File

@ -6,8 +6,9 @@ $folder = "D:\some\folder\location\"
$file_pattern="*.*" $file_pattern="*.*"
$caption_text="some caption text" $caption_text="some caption text"
$files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File $files = Get-ChildItem $folder$file_pattern -Include *.png, *.jpg, *.webp -File
foreach ($file in $files) foreach ($file in $files) {
{ if (-not(Test-Path -Path $folder\"$($file.BaseName).txt" -PathType Leaf)) {
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
}
} }

View File

@ -5,6 +5,7 @@
# enable reg images in fine-tuning, add dataset_repeats option # enable reg images in fine-tuning, add dataset_repeats option
# v8: supports Diffusers 0.7.2 # v8: supports Diffusers 0.7.2
# v9: add bucketing option # v9: add bucketing option
# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
import time import time
from torch.autograd.function import Function from torch.autograd.function import Function
@ -143,7 +144,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
) )
# bucketingを行わない場合も呼び出し必須ひとつだけbucketを作る # bucketingを行わない場合も呼び出し必須ひとつだけbucketを作る
def make_buckets_with_caching(self, enable_bucket, vae): def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size):
self.enable_bucket = enable_bucket self.enable_bucket = enable_bucket
cache_latents = vae is not None cache_latents = vae is not None
@ -160,7 +161,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
# bucketingを用意する # bucketingを用意する
if enable_bucket: if enable_bucket:
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height)) bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size)
else: else:
# bucketはひとつだけ、すべての画像は同じ解像度 # bucketはひとつだけ、すべての画像は同じ解像度
bucket_resos = [(self.width, self.height)] bucket_resos = [(self.width, self.height)]
@ -387,7 +388,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
latents_list.append(latents) latents_list.append(latents)
# captionを処理する # captionを処理する
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする if self.shuffle_caption: # captionのshuffleをする
tokens = caption.strip().split(",") tokens = caption.strip().split(",")
random.shuffle(tokens) random.shuffle(tokens)
caption = ",".join(tokens).strip() caption = ",".join(tokens).strip()
@ -400,7 +401,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
else: else:
# paddingする # paddingする
input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
example = {} example = {}
example['loss_weights'] = torch.FloatTensor(loss_weights) example['loss_weights'] = torch.FloatTensor(loss_weights)
example['input_ids'] = input_ids example['input_ids'] = input_ids
@ -1133,44 +1134,55 @@ def train(args):
set_seed(args.seed) set_seed(args.seed)
# 学習データを用意する # 学習データを用意する
def read_caption(img_path):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + args.caption_extention, base_name_face_det + args.caption_extention]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding='utf-8') as f:
caption = f.readlines()[0].strip()
break
return caption
def load_dreambooth_dir(dir): def load_dreambooth_dir(dir):
tokens = os.path.basename(dir).split('_') tokens = os.path.basename(dir).split('_')
try: try:
n_repeats = int(tokens[0]) n_repeats = int(tokens[0])
except ValueError as e: except ValueError as e:
# print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
# raise e
return 0, [] return 0, []
caption = '_'.join(tokens[1:]) caption = '_'.join(tokens[1:])
print(f"found directory {n_repeats}_{caption}") print(f"found directory {n_repeats}_{caption}")
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp")) img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
return n_repeats, [(ip, caption) for ip in img_paths] glob.glob(os.path.join(dir, "*.webp"))
# 画像ファイルごとにプロンプトを読み込み、もしあれば連結する
captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path)
captions.append(caption + ("" if cap_for_img is None else cap_for_img))
return n_repeats, list(zip(img_paths, captions))
print("prepare train images.") print("prepare train images.")
train_img_path_captions = [] train_img_path_captions = []
if fine_tuning: if fine_tuning:
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
for img_path in tqdm(img_paths): for img_path in tqdm(img_paths):
# captionの候補ファイル名を作る caption = read_caption(img_path)
base_name = os.path.splitext(img_path)[0] assert caption is not None and len(
base_name_face_det = base_name caption) > 0, f"no caption for image. check caption_extention option / キャプションファイルが見つからないかcaptionが空です。caption_extentionオプションを確認してください: {img_path}"
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + '.txt', base_name + '.caption', base_name_face_det+'.txt', base_name_face_det+'.caption']
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding='utf-8') as f:
caption = f.readlines()[0].strip()
break
assert caption is not None and len(caption) > 0, f"no caption / キャプションファイルが見つからないか、captionが空です: {cap_paths}"
train_img_path_captions.append((img_path, caption)) train_img_path_captions.append((img_path, caption))
@ -1201,8 +1213,12 @@ def train(args):
resolution = tuple([int(r) for r in args.resolution.split(',')]) resolution = tuple([int(r) for r in args.resolution.split(',')])
if len(resolution) == 1: if len(resolution) == 1:
resolution = (resolution[0], resolution[0]) resolution = (resolution[0], resolution[0])
assert len( assert len(resolution) == 2, \
resolution) == 2, f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'','高さ'で指定してください: {args.resolution}" f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'','高さ'で指定してください: {args.resolution}"
if args.enable_bucket:
assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or greater than resolution / min_bucket_resoは解像度の数値以上で指定してください"
assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or less than resolution / max_bucket_resoは解像度の数値以下で指定してください"
if args.face_crop_aug_range is not None: if args.face_crop_aug_range is not None:
face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
@ -1221,7 +1237,8 @@ def train(args):
args.shuffle_caption, args.no_token_padding, args.debug_dataset) args.shuffle_caption, args.no_token_padding, args.debug_dataset)
if args.debug_dataset: if args.debug_dataset:
train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso,
args.max_bucket_reso) # デバッグ用にcacheなしで作る
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します") print("Escape for exit. / Escキーで中断、終了します")
for example in train_dataset: for example in train_dataset:
@ -1282,14 +1299,20 @@ def train(args):
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset.make_buckets_with_caching(args.enable_bucket, vae) train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso)
del vae del vae
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
else: else:
train_dataset.make_buckets_with_caching(args.enable_bucket, None) train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso)
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval()
unet.requires_grad_(True) # 念のため追加
text_encoder.requires_grad_(True)
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
@ -1343,7 +1366,7 @@ def train(args):
print("running training / 学習開始") print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}") print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}") print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
@ -1457,7 +1480,7 @@ def train(args):
text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training() accelerator.end_training()
if args.save_state: if args.save_state:
print("saving last state.") print("saving last state.")
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
@ -1720,11 +1743,13 @@ def replace_unet_cross_attn_to_xformers():
k_in = self.to_k(context) k_in = self.to_k(context)
v_in = self.to_v(context) v_in = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format
# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format
del q_in, k_in, v_in del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
out = rearrange(out, 'b n h d -> b n (h d)', h=h) out = rearrange(out, 'b n h d -> b n (h d)', h=h)
# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
# diffusers 0.6.0 # diffusers 0.6.0
if type(self.to_out) is torch.nn.Sequential: if type(self.to_out) is torch.nn.Sequential:
@ -1747,7 +1772,8 @@ if __name__ == '__main__':
parser.add_argument("--fine_tuning", action="store_true", parser.add_argument("--fine_tuning", action="store_true",
help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする") help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
parser.add_argument("--shuffle_caption", action="store_true", parser.add_argument("--shuffle_caption", action="store_true",
help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする") help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption files / 読み込むcaptionファイルの拡張子")
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
parser.add_argument("--dataset_repeats", type=int, default=None, parser.add_argument("--dataset_repeats", type=int, default=None,
@ -1758,8 +1784,7 @@ if __name__ == '__main__':
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存しますStableDiffusion形式のモデルを読み込んだ場合のみ有効") help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存しますStableDiffusion形式のモデルを読み込んだ場合のみ有効")
parser.add_argument("--save_state", action="store_true", parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
help="saved state to resume training / 学習再開するモデルのstate")
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
parser.add_argument("--no_token_padding", action="store_true", parser.add_argument("--no_token_padding", action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作") help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作")
@ -1785,6 +1810,8 @@ if __name__ == '__main__':
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可") help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可")
parser.add_argument("--enable_bucket", action="store_true", parser.add_argument("--enable_bucket", action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")