Adding support for caption dropout

This commit is contained in:
bmaltais 2023-02-07 20:58:35 -05:00
parent 8d559ded18
commit 09d3a72cd8
12 changed files with 121 additions and 21 deletions

View File

@ -88,6 +88,7 @@ def save_configuration(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -177,6 +178,7 @@ def open_configuration(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -250,6 +252,7 @@ def train_model(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -416,6 +419,8 @@ def train_model(
bucket_no_upscale=bucket_no_upscale, bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop, random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps, bucket_reso_steps=bucket_reso_steps,
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
caption_dropout_rate=caption_dropout_rate,
) )
print(run_cmd) print(run_cmd)
@ -627,6 +632,7 @@ def dreambooth_tab(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -695,6 +701,7 @@ def dreambooth_tab(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
] ]
button_open_config.click( button_open_config.click(

View File

@ -36,6 +36,10 @@ def train(args):
args.bucket_reso_steps, args.bucket_no_upscale, args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset) args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.make_buckets() train_dataset.make_buckets()
if args.debug_dataset: if args.debug_dataset:
@ -226,6 +230,9 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
for m in training_models: for m in training_models:
m.train() m.train()
@ -332,7 +339,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True) train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False) train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)

View File

@ -84,6 +84,7 @@ def save_configuration(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -179,6 +180,7 @@ def open_config_file(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -259,6 +261,7 @@ def train_model(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# create caption json file # create caption json file
if generate_caption_database: if generate_caption_database:
@ -405,6 +408,8 @@ def train_model(
bucket_no_upscale=bucket_no_upscale, bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop, random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps, bucket_reso_steps=bucket_reso_steps,
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
caption_dropout_rate=caption_dropout_rate,
) )
print(run_cmd) print(run_cmd)
@ -614,6 +619,7 @@ def finetune_tab():
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -678,6 +684,7 @@ def finetune_tab():
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

View File

@ -563,6 +563,15 @@ def gradio_advanced_training():
random_crop = gr.Checkbox( random_crop = gr.Checkbox(
label='Random crop instead of center crop', value=False label='Random crop instead of center crop', value=False
) )
with gr.Row():
caption_dropout_every_n_epochs = gr.Number(
label="Dropout caption every n epochs",
value=0
)
caption_dropout_rate = gr.Number(
label="Rate of caption dropout",
value=0
)
with gr.Row(): with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False) save_state = gr.Checkbox(label='Save training state', value=False)
resume = gr.Textbox( resume = gr.Textbox(
@ -599,6 +608,7 @@ def gradio_advanced_training():
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
) )
@ -622,6 +632,12 @@ def run_cmd_advanced_training(**kwargs):
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"' f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
if int(kwargs.get('keep_tokens', 0)) > 0 if int(kwargs.get('keep_tokens', 0)) > 0
else '', else '',
f' --caption_dropout_every_n_epochs="{kwargs.get("caption_dropout_every_n_epochs", "")}"'
if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
else '',
f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"'
if float(kwargs.get('caption_dropout_rate', 0)) > 0
else '',
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
if int(kwargs.get('bucket_reso_steps', 64)) >= 1 if int(kwargs.get('bucket_reso_steps', 64)) >= 1

View File

@ -113,7 +113,7 @@ class BucketManager():
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
self.predefined_resos = resos.copy() self.predefined_resos = resos.copy()
self.predefined_resos_set = set(resos) self.predefined_resos_set = set(resos)
self.predifined_aspect_ratios = np.array([w / h for w, h in resos]) self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
def add_if_new_reso(self, reso): def add_if_new_reso(self, reso):
if reso not in self.reso_to_id: if reso not in self.reso_to_id:
@ -135,7 +135,7 @@ class BucketManager():
if reso in self.predefined_resos_set: if reso in self.predefined_resos_set:
pass pass
else: else:
ar_errors = self.predifined_aspect_ratios - aspect_ratio ar_errors = self.predefined_aspect_ratios - aspect_ratio
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
reso = self.predefined_resos[predefined_bucket_id] reso = self.predefined_resos[predefined_bucket_id]
@ -223,6 +223,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
# TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう
self.epoch_current: int = int(0)
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
# augmentation # augmentation
flip_p = 0.5 if flip_aug else 0.0 flip_p = 0.5 if flip_aug else 0.0
if color_aug: if color_aug:
@ -247,6 +252,12 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {} self.replacements = {}
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
# 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく
# コンストラクタで渡さないのはTextual Inversionで意識したくないからということにしておく
self.dropout_rate = dropout_rate
self.dropout_every_n_epochs = dropout_every_n_epochs
def set_tag_frequency(self, dir_name, captions): def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {}) frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir self.tag_frequency[dir_name] = frequency_for_dir
@ -265,7 +276,7 @@ class BaseDataset(torch.utils.data.Dataset):
def process_caption(self, caption): def process_caption(self, caption):
if self.shuffle_caption: if self.shuffle_caption:
tokens = caption.strip().split(",") tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None: if self.shuffle_keep_tokens is None:
random.shuffle(tokens) random.shuffle(tokens)
else: else:
@ -274,7 +285,7 @@ class BaseDataset(torch.utils.data.Dataset):
tokens = tokens[self.shuffle_keep_tokens:] tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens) random.shuffle(tokens)
tokens = keep_tokens + tokens tokens = keep_tokens + tokens
caption = ",".join(tokens).strip() caption = ", ".join(tokens)
for str_from, str_to in self.replacements.items(): for str_from, str_to in self.replacements.items():
if str_from == "": if str_from == "":
@ -598,6 +609,17 @@ class BaseDataset(torch.utils.data.Dataset):
images.append(image) images.append(image)
latents_list.append(latents) latents_list.append(latents)
# dropoutの決定
is_drop_out = False
if self.dropout_rate > 0 and random.random() < self.dropout_rate:
is_drop_out = True
if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
is_drop_out = True
if is_drop_out:
caption = ""
print(f"Drop caption out: {self.process_caption(image_info.caption)}")
else:
caption = self.process_caption(image_info.caption) caption = self.process_caption(image_info.caption)
captions.append(caption) captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future if not self.token_padding_disabled: # this option might be omitted in future
@ -1377,7 +1399,7 @@ def verify_training_args(args: argparse.Namespace):
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool): def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
# dataset common # dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--shuffle_caption", action="store_true", parser.add_argument("--shuffle_caption", action="store_true",
@ -1408,6 +1430,14 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
parser.add_argument("--bucket_no_upscale", action="store_true", parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
parser.add_argument("--caption_dropout_rate", type=float, default=0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
if support_dreambooth: if support_dreambooth:
# DreamBooth dataset # DreamBooth dataset
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")

View File

@ -99,6 +99,7 @@ def save_configuration(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -195,6 +196,7 @@ def open_configuration(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -275,6 +277,7 @@ def train_model(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -380,7 +383,7 @@ def train_model(
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"' run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop' # run_cmd += f' --caption_dropout_rate="0.1" --caption_dropout_every_n_epochs=1' # --random_crop'
if v2: if v2:
run_cmd += ' --v2' run_cmd += ' --v2'
@ -440,7 +443,7 @@ def train_model(
else: else:
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"' run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
if not lr_scheduler_power == '': if not lr_scheduler_power == '':
run_cmd += f' --output_name="{lr_scheduler_power}"' run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
run_cmd += run_cmd_training( run_cmd += run_cmd_training(
learning_rate=learning_rate, learning_rate=learning_rate,
@ -476,6 +479,8 @@ def train_model(
bucket_no_upscale=bucket_no_upscale, bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop, random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps, bucket_reso_steps=bucket_reso_steps,
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
caption_dropout_rate=caption_dropout_rate,
) )
print(run_cmd) print(run_cmd)
@ -725,6 +730,7 @@ def lora_tab(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -805,6 +811,7 @@ def lora_tab(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
] ]
button_open_config.click( button_open_config.click(

View File

@ -5,6 +5,7 @@
import math import math
import os import os
from typing import List
import torch import torch
from library import train_util from library import train_util
@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module):
self.alpha = alpha self.alpha = alpha
# create module instances # create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
loras = [] loras = []
for name, module in root_module.named_modules(): for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules: if module.__class__.__name__ in target_replace_modules:

View File

@ -94,6 +94,7 @@ def save_configuration(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -193,6 +194,7 @@ def open_configuration(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -272,6 +274,7 @@ def train_model(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -453,6 +456,8 @@ def train_model(
bucket_no_upscale=bucket_no_upscale, bucket_no_upscale=bucket_no_upscale,
random_crop=random_crop, random_crop=random_crop,
bucket_reso_steps=bucket_reso_steps, bucket_reso_steps=bucket_reso_steps,
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
caption_dropout_rate=caption_dropout_rate,
) )
run_cmd += f' --token_string="{token_string}"' run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"' run_cmd += f' --init_word="{init_word}"'
@ -709,6 +714,7 @@ def ti_tab(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -783,6 +789,7 @@ def ti_tab(
bucket_no_upscale, bucket_no_upscale,
random_crop, random_crop,
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
] ]
button_open_config.click( button_open_config.click(

View File

@ -4,7 +4,7 @@ import argparse
import shutil import shutil
import math import math
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2): def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=1):
# Split the max_resolution string by "," and strip any whitespaces # Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')] max_resolutions = [res.strip() for res in max_resolution.split(',')]
@ -58,6 +58,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
base, ext = os.path.splitext(filename) base, ext = os.path.splitext(filename)
new_filename = base + '+' + max_resolution + '.jpg' new_filename = base + '+' + max_resolution + '.jpg'
# copy caption file with right name if one exist
if os.path.exists(os.path.join(src_img_folder, base + '.txt')):
shutil.copy(os.path.join(src_img_folder, base + '.txt'), os.path.join(dst_img_folder, new_filename + '.txt'))
# Save resized image in dst_img_folder # Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")

View File

@ -38,8 +38,13 @@ def train(args):
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale, args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
if args.no_token_padding: if args.no_token_padding:
train_dataset.disable_token_padding() train_dataset.disable_token_padding()
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.make_buckets() train_dataset.make_buckets()
if args.debug_dataset: if args.debug_dataset:
@ -204,6 +209,8 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
# 指定したステップ数までText Encoderを学習するepoch最初の状態 # 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train() unet.train()
# train==True is required to enable gradient_checkpointing # train==True is required to enable gradient_checkpointing
@ -327,7 +334,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False) train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)

View File

@ -132,6 +132,10 @@ def train(args):
args.bucket_reso_steps, args.bucket_no_upscale, args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset) args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.make_buckets() train_dataset.make_buckets()
if args.debug_dataset: if args.debug_dataset:
@ -376,6 +380,9 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
metadata["ss_epoch"] = str(epoch+1) metadata["ss_epoch"] = str(epoch+1)
network.on_epoch_start(text_encoder, unet) network.on_epoch_start(text_encoder, unet)
@ -509,7 +516,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True) train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")

View File

@ -478,7 +478,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True) train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],