Update to v10
This commit is contained in:
parent
a6c7bb06dc
commit
2629617de7
@ -135,3 +135,8 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
|
|||||||
- Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision)
|
- Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision)
|
||||||
- Added support for saving learning state (--save_state, --resume)
|
- Added support for saving learning state (--save_state, --resume)
|
||||||
- Added support for logging (--logging_dir)
|
- Added support for logging (--logging_dir)
|
||||||
|
* 11/21 (v10):
|
||||||
|
- Added minimum/maximum resolution specification when using Aspect Ratio Bucketing (min_bucket_reso/max_bucket_reso options).
|
||||||
|
- Added extension specification for caption files (caption_extention).
|
||||||
|
- Added support for images with .webp extension.
|
||||||
|
- Added support for per-image captions on training images and regularization images.
|
@ -6,8 +6,9 @@ $folder = "D:\some\folder\location\"
|
|||||||
$file_pattern="*.*"
|
$file_pattern="*.*"
|
||||||
$caption_text="some caption text"
|
$caption_text="some caption text"
|
||||||
|
|
||||||
$files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File
|
$files = Get-ChildItem $folder$file_pattern -Include *.png, *.jpg, *.webp -File
|
||||||
foreach ($file in $files)
|
foreach ($file in $files) {
|
||||||
{
|
if (-not(Test-Path -Path $folder\"$($file.BaseName).txt" -PathType Leaf)) {
|
||||||
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
|
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
|
||||||
|
}
|
||||||
}
|
}
|
@ -5,6 +5,7 @@
|
|||||||
# enable reg images in fine-tuning, add dataset_repeats option
|
# enable reg images in fine-tuning, add dataset_repeats option
|
||||||
# v8: supports Diffusers 0.7.2
|
# v8: supports Diffusers 0.7.2
|
||||||
# v9: add bucketing option
|
# v9: add bucketing option
|
||||||
|
# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from torch.autograd.function import Function
|
from torch.autograd.function import Function
|
||||||
@ -143,7 +144,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
# bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
||||||
def make_buckets_with_caching(self, enable_bucket, vae):
|
def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size):
|
||||||
self.enable_bucket = enable_bucket
|
self.enable_bucket = enable_bucket
|
||||||
|
|
||||||
cache_latents = vae is not None
|
cache_latents = vae is not None
|
||||||
@ -160,7 +161,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
# bucketingを用意する
|
# bucketingを用意する
|
||||||
if enable_bucket:
|
if enable_bucket:
|
||||||
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height))
|
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size)
|
||||||
else:
|
else:
|
||||||
# bucketはひとつだけ、すべての画像は同じ解像度
|
# bucketはひとつだけ、すべての画像は同じ解像度
|
||||||
bucket_resos = [(self.width, self.height)]
|
bucket_resos = [(self.width, self.height)]
|
||||||
@ -387,7 +388,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
latents_list.append(latents)
|
latents_list.append(latents)
|
||||||
|
|
||||||
# captionを処理する
|
# captionを処理する
|
||||||
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする
|
if self.shuffle_caption: # captionのshuffleをする
|
||||||
tokens = caption.strip().split(",")
|
tokens = caption.strip().split(",")
|
||||||
random.shuffle(tokens)
|
random.shuffle(tokens)
|
||||||
caption = ",".join(tokens).strip()
|
caption = ",".join(tokens).strip()
|
||||||
@ -400,7 +401,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
|||||||
else:
|
else:
|
||||||
# paddingする
|
# paddingする
|
||||||
input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
|
input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
|
||||||
|
|
||||||
example = {}
|
example = {}
|
||||||
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
||||||
example['input_ids'] = input_ids
|
example['input_ids'] = input_ids
|
||||||
@ -1133,44 +1134,55 @@ def train(args):
|
|||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
# 学習データを用意する
|
# 学習データを用意する
|
||||||
|
def read_caption(img_path):
    """Return the first line of the caption file that belongs to an image.

    Two candidate paths are tried in order, both built with the extension in
    ``args.caption_extention``: the image path with its own extension removed,
    and that same path with its last four ``_``-separated fields dropped (only
    when the path has at least five such fields).

    Returns the stripped first line of the first candidate that exists, an
    empty string when that file is empty, or ``None`` when neither candidate
    exists.
    """
    # Build the candidate caption file names.
    base_name = os.path.splitext(img_path)[0]
    base_name_face_det = base_name
    # NOTE(review): the split runs over the whole path, not only the file name,
    # so underscores in directory names count as fields too. The four dropped
    # fields are presumably a suffix appended by a face-detection step - confirm.
    tokens = base_name.split("_")
    if len(tokens) >= 5:
        base_name_face_det = "_".join(tokens[:-4])
    cap_paths = [base_name + args.caption_extention, base_name_face_det + args.caption_extention]

    for cap_path in cap_paths:
        if os.path.isfile(cap_path):
            with open(cap_path, "rt", encoding='utf-8') as f:
                # readline() gives "" for an empty file, whereas indexing
                # readlines()[0] would raise IndexError before the caller can
                # report the empty caption itself.
                return f.readline().strip()
    return None
|
||||||
|
|
||||||
def load_dreambooth_dir(dir):
|
def load_dreambooth_dir(dir):
|
||||||
tokens = os.path.basename(dir).split('_')
|
tokens = os.path.basename(dir).split('_')
|
||||||
try:
|
try:
|
||||||
n_repeats = int(tokens[0])
|
n_repeats = int(tokens[0])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
|
|
||||||
# raise e
|
|
||||||
return 0, []
|
return 0, []
|
||||||
|
|
||||||
caption = '_'.join(tokens[1:])
|
caption = '_'.join(tokens[1:])
|
||||||
|
|
||||||
print(f"found directory {n_repeats}_{caption}")
|
print(f"found directory {n_repeats}_{caption}")
|
||||||
|
|
||||||
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp"))
|
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
|
||||||
return n_repeats, [(ip, caption) for ip in img_paths]
|
glob.glob(os.path.join(dir, "*.webp"))
|
||||||
|
|
||||||
|
# 画像ファイルごとにプロンプトを読み込み、もしあれば連結する
|
||||||
|
captions = []
|
||||||
|
for img_path in img_paths:
|
||||||
|
cap_for_img = read_caption(img_path)
|
||||||
|
captions.append(caption + ("" if cap_for_img is None else cap_for_img))
|
||||||
|
|
||||||
|
return n_repeats, list(zip(img_paths, captions))
|
||||||
|
|
||||||
print("prepare train images.")
|
print("prepare train images.")
|
||||||
train_img_path_captions = []
|
train_img_path_captions = []
|
||||||
|
|
||||||
if fine_tuning:
|
if fine_tuning:
|
||||||
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
|
||||||
|
glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
for img_path in tqdm(img_paths):
|
for img_path in tqdm(img_paths):
|
||||||
# captionの候補ファイル名を作る
|
caption = read_caption(img_path)
|
||||||
base_name = os.path.splitext(img_path)[0]
|
assert caption is not None and len(
|
||||||
base_name_face_det = base_name
|
caption) > 0, f"no caption for image. check caption_extention option / キャプションファイルが見つからないかcaptionが空です。caption_extentionオプションを確認してください: {img_path}"
|
||||||
tokens = base_name.split("_")
|
|
||||||
if len(tokens) >= 5:
|
|
||||||
base_name_face_det = "_".join(tokens[:-4])
|
|
||||||
cap_paths = [base_name + '.txt', base_name + '.caption', base_name_face_det+'.txt', base_name_face_det+'.caption']
|
|
||||||
|
|
||||||
caption = None
|
|
||||||
for cap_path in cap_paths:
|
|
||||||
if os.path.isfile(cap_path):
|
|
||||||
with open(cap_path, "rt", encoding='utf-8') as f:
|
|
||||||
caption = f.readlines()[0].strip()
|
|
||||||
break
|
|
||||||
|
|
||||||
assert caption is not None and len(caption) > 0, f"no caption / キャプションファイルが見つからないか、captionが空です: {cap_paths}"
|
|
||||||
|
|
||||||
train_img_path_captions.append((img_path, caption))
|
train_img_path_captions.append((img_path, caption))
|
||||||
|
|
||||||
@ -1201,8 +1213,12 @@ def train(args):
|
|||||||
resolution = tuple([int(r) for r in args.resolution.split(',')])
|
resolution = tuple([int(r) for r in args.resolution.split(',')])
|
||||||
if len(resolution) == 1:
|
if len(resolution) == 1:
|
||||||
resolution = (resolution[0], resolution[0])
|
resolution = (resolution[0], resolution[0])
|
||||||
assert len(
|
assert len(resolution) == 2, \
|
||||||
resolution) == 2, f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
||||||
|
|
||||||
|
if args.enable_bucket:
|
||||||
|
assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or greater than resolution / min_bucket_resoは解像度の数値以上で指定してください"
|
||||||
|
assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or less than resolution / max_bucket_resoは解像度の数値以下で指定してください"
|
||||||
|
|
||||||
if args.face_crop_aug_range is not None:
|
if args.face_crop_aug_range is not None:
|
||||||
face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
||||||
@ -1221,7 +1237,8 @@ def train(args):
|
|||||||
args.shuffle_caption, args.no_token_padding, args.debug_dataset)
|
args.shuffle_caption, args.no_token_padding, args.debug_dataset)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る
|
train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso,
|
||||||
|
args.max_bucket_reso) # デバッグ用にcacheなしで作る
|
||||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||||
print("Escape for exit. / Escキーで中断、終了します")
|
print("Escape for exit. / Escキーで中断、終了します")
|
||||||
for example in train_dataset:
|
for example in train_dataset:
|
||||||
@ -1282,14 +1299,20 @@ def train(args):
|
|||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset.make_buckets_with_caching(args.enable_bucket, vae)
|
train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso)
|
||||||
del vae
|
del vae
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
train_dataset.make_buckets_with_caching(args.enable_bucket, None)
|
train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
|
unet.requires_grad_(True) # 念のため追加
|
||||||
|
text_encoder.requires_grad_(True)
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
@ -1343,7 +1366,7 @@ def train(args):
|
|||||||
print("running training / 学習開始")
|
print("running training / 学習開始")
|
||||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
||||||
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
||||||
print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}")
|
print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}")
|
||||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
@ -1457,7 +1480,7 @@ def train(args):
|
|||||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if args.save_state:
|
if args.save_state:
|
||||||
print("saving last state.")
|
print("saving last state.")
|
||||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
|
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
|
||||||
@ -1720,11 +1743,13 @@ def replace_unet_cross_attn_to_xformers():
|
|||||||
k_in = self.to_k(context)
|
k_in = self.to_k(context)
|
||||||
v_in = self.to_v(context)
|
v_in = self.to_v(context)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format
|
||||||
|
# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||||
|
|
||||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||||
|
# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
# diffusers 0.6.0
|
# diffusers 0.6.0
|
||||||
if type(self.to_out) is torch.nn.Sequential:
|
if type(self.to_out) is torch.nn.Sequential:
|
||||||
@ -1747,7 +1772,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--fine_tuning", action="store_true",
|
parser.add_argument("--fine_tuning", action="store_true",
|
||||||
help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
|
help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
|
||||||
parser.add_argument("--shuffle_caption", action="store_true",
|
parser.add_argument("--shuffle_caption", action="store_true",
|
||||||
help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする")
|
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
|
||||||
|
parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption files / 読み込むcaptionファイルの拡張子")
|
||||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
||||||
parser.add_argument("--dataset_repeats", type=int, default=None,
|
parser.add_argument("--dataset_repeats", type=int, default=None,
|
||||||
@ -1758,8 +1784,7 @@ if __name__ == '__main__':
|
|||||||
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
parser.add_argument("--resume", type=str, default=None,
|
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
help="saved state to resume training / 学習再開するモデルのstate")
|
|
||||||
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
|
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
|
||||||
parser.add_argument("--no_token_padding", action="store_true",
|
parser.add_argument("--no_token_padding", action="store_true",
|
||||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
||||||
@ -1785,6 +1810,8 @@ if __name__ == '__main__':
|
|||||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
||||||
parser.add_argument("--enable_bucket", action="store_true",
|
parser.add_argument("--enable_bucket", action="store_true",
|
||||||
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
||||||
|
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||||
|
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||||
|
Loading…
Reference in New Issue
Block a user