Update to v10
This commit is contained in:
parent
a6c7bb06dc
commit
2629617de7
@ -135,3 +135,8 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n
|
||||
- Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision)
|
||||
- Added support for saving learning state (--save_state, --resume)
|
||||
- Added support for logging (--logging_dir)
|
||||
* 11/21 (v10):
|
||||
- Added minimum/maximum resolution specification when using Aspect Ratio Bucketing (min_bucket_reso/max_bucket_reso option).
|
||||
- Added extension specification for caption files (caption_extention).
|
||||
- Added support for images with .webp extension.
|
||||
- Added a function that allows captions to learning images and regularized images.
|
@ -6,8 +6,9 @@ $folder = "D:\some\folder\location\"
|
||||
$file_pattern="*.*"
|
||||
$caption_text="some caption text"
|
||||
|
||||
$files = Get-ChildItem $folder$file_pattern -Include *.png,*.jpg,*.webp -File
|
||||
foreach ($file in $files)
|
||||
{
|
||||
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
|
||||
$files = Get-ChildItem $folder$file_pattern -Include *.png, *.jpg, *.webp -File
|
||||
foreach ($file in $files) {
|
||||
if (-not(Test-Path -Path $folder\"$($file.BaseName).txt" -PathType Leaf)) {
|
||||
New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@
|
||||
# enable reg images in fine-tuning, add dataset_repeats option
|
||||
# v8: supports Diffusers 0.7.2
|
||||
# v9: add bucketing option
|
||||
# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
|
||||
|
||||
import time
|
||||
from torch.autograd.function import Function
|
||||
@ -143,7 +144,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
# bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
||||
def make_buckets_with_caching(self, enable_bucket, vae):
|
||||
def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size):
|
||||
self.enable_bucket = enable_bucket
|
||||
|
||||
cache_latents = vae is not None
|
||||
@ -160,7 +161,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
||||
|
||||
# bucketingを用意する
|
||||
if enable_bucket:
|
||||
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height))
|
||||
bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size)
|
||||
else:
|
||||
# bucketはひとつだけ、すべての画像は同じ解像度
|
||||
bucket_resos = [(self.width, self.height)]
|
||||
@ -387,7 +388,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
||||
latents_list.append(latents)
|
||||
|
||||
# captionを処理する
|
||||
if self.fine_tuning and self.shuffle_caption: # fine tuning時にcaptionのshuffleをする
|
||||
if self.shuffle_caption: # captionのshuffleをする
|
||||
tokens = caption.strip().split(",")
|
||||
random.shuffle(tokens)
|
||||
caption = ",".join(tokens).strip()
|
||||
@ -400,7 +401,7 @@ class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# paddingする
|
||||
input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
|
||||
|
||||
|
||||
example = {}
|
||||
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
||||
example['input_ids'] = input_ids
|
||||
@ -1133,44 +1134,55 @@ def train(args):
|
||||
set_seed(args.seed)
|
||||
|
||||
# 学習データを用意する
|
||||
def read_caption(img_path):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
tokens = base_name.split("_")
|
||||
if len(tokens) >= 5:
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [base_name + args.caption_extention, base_name_face_det + args.caption_extention]
|
||||
|
||||
caption = None
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding='utf-8') as f:
|
||||
caption = f.readlines()[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
def load_dreambooth_dir(dir):
|
||||
tokens = os.path.basename(dir).split('_')
|
||||
try:
|
||||
n_repeats = int(tokens[0])
|
||||
except ValueError as e:
|
||||
# print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
|
||||
# raise e
|
||||
return 0, []
|
||||
|
||||
caption = '_'.join(tokens[1:])
|
||||
|
||||
print(f"found directory {n_repeats}_{caption}")
|
||||
|
||||
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + glob.glob(os.path.join(dir, "*.webp"))
|
||||
return n_repeats, [(ip, caption) for ip in img_paths]
|
||||
img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
|
||||
glob.glob(os.path.join(dir, "*.webp"))
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあれば連結する
|
||||
captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path)
|
||||
captions.append(caption + ("" if cap_for_img is None else cap_for_img))
|
||||
|
||||
return n_repeats, list(zip(img_paths, captions))
|
||||
|
||||
print("prepare train images.")
|
||||
train_img_path_captions = []
|
||||
|
||||
if fine_tuning:
|
||||
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
for img_path in tqdm(img_paths):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
tokens = base_name.split("_")
|
||||
if len(tokens) >= 5:
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [base_name + '.txt', base_name + '.caption', base_name_face_det+'.txt', base_name_face_det+'.caption']
|
||||
|
||||
caption = None
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding='utf-8') as f:
|
||||
caption = f.readlines()[0].strip()
|
||||
break
|
||||
|
||||
assert caption is not None and len(caption) > 0, f"no caption / キャプションファイルが見つからないか、captionが空です: {cap_paths}"
|
||||
caption = read_caption(img_path)
|
||||
assert caption is not None and len(
|
||||
caption) > 0, f"no caption for image. check caption_extention option / キャプションファイルが見つからないかcaptionが空です。caption_extentionオプションを確認してください: {img_path}"
|
||||
|
||||
train_img_path_captions.append((img_path, caption))
|
||||
|
||||
@ -1201,8 +1213,12 @@ def train(args):
|
||||
resolution = tuple([int(r) for r in args.resolution.split(',')])
|
||||
if len(resolution) == 1:
|
||||
resolution = (resolution[0], resolution[0])
|
||||
assert len(
|
||||
resolution) == 2, f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
||||
assert len(resolution) == 2, \
|
||||
f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
||||
|
||||
if args.enable_bucket:
|
||||
assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or greater than resolution / min_bucket_resoは解像度の数値以上で指定してください"
|
||||
assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or less than resolution / max_bucket_resoは解像度の数値以下で指定してください"
|
||||
|
||||
if args.face_crop_aug_range is not None:
|
||||
face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
||||
@ -1221,7 +1237,8 @@ def train(args):
|
||||
args.shuffle_caption, args.no_token_padding, args.debug_dataset)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_dataset.make_buckets_with_caching(args.enable_bucket, None) # デバッグ用にcacheなしで作る
|
||||
train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso,
|
||||
args.max_bucket_reso) # デバッグ用にcacheなしで作る
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
for example in train_dataset:
|
||||
@ -1282,14 +1299,20 @@ def train(args):
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset.make_buckets_with_caching(args.enable_bucket, vae)
|
||||
train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso)
|
||||
del vae
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
train_dataset.make_buckets_with_caching(args.enable_bucket, None)
|
||||
train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
unet.requires_grad_(True) # 念のため追加
|
||||
text_encoder.requires_grad_(True)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
@ -1343,7 +1366,7 @@ def train(args):
|
||||
print("running training / 学習開始")
|
||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
||||
print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}")
|
||||
print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
@ -1457,7 +1480,7 @@ def train(args):
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if args.save_state:
|
||||
print("saving last state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
|
||||
@ -1720,11 +1743,13 @@ def replace_unet_cross_attn_to_xformers():
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format
|
||||
# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format
|
||||
del q_in, k_in, v_in
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
# diffusers 0.6.0
|
||||
if type(self.to_out) is torch.nn.Sequential:
|
||||
@ -1747,7 +1772,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--fine_tuning", action="store_true",
|
||||
help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
|
||||
parser.add_argument("--shuffle_caption", action="store_true",
|
||||
help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする")
|
||||
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
|
||||
parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption files / 読み込むcaptionファイルの拡張子")
|
||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
||||
parser.add_argument("--dataset_repeats", type=int, default=None,
|
||||
@ -1758,8 +1784,7 @@ if __name__ == '__main__':
|
||||
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
||||
parser.add_argument("--save_state", action="store_true",
|
||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||
parser.add_argument("--resume", type=str, default=None,
|
||||
help="saved state to resume training / 学習再開するモデルのstate")
|
||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
|
||||
parser.add_argument("--no_token_padding", action="store_true",
|
||||
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
|
||||
@ -1785,6 +1810,8 @@ if __name__ == '__main__':
|
||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
||||
parser.add_argument("--enable_bucket", action="store_true",
|
||||
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
|
Loading…
Reference in New Issue
Block a user