Merge pull request #474 from bmaltais/dev

v21.3.6
This commit is contained in:
bmaltais 2023-03-28 11:55:22 -04:00 committed by GitHub
commit 79bffae5ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 4724 additions and 2450 deletions

View File

@ -213,6 +213,16 @@ This will store your a backup file with your current locally installed pip packa
## Change History ## Change History
* 2023/03/28 (v21.3.6)
- Fix issues when `--persistent_data_loader_workers` is specified.
- The batch members of the bucket are not shuffled.
- `--caption_dropout_every_n_epochs` does not work.
- These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue.
- Fix an issue that images are loaded twice in Windows environment.
- Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work!
- Add `--min_snr_gamma` option to training scripts, 5 is recommended by paper.
- The Min SNR gamma fiels can be found unser the advanced training tab in all trainers.
- Fixed the error while images are ended with capital image extensions. Thanks to @kvzn. https://github.com/bmaltais/kohya_ss/pull/454
* 2023/03/26 (v21.3.5) * 2023/03/26 (v21.3.5)
- Fix for https://github.com/bmaltais/kohya_ss/issues/230 - Fix for https://github.com/bmaltais/kohya_ss/issues/230
- Added detection for Google Colab to not bring up the GUI file/folder window on the platform. Instead it will only use the file/folder path provided in the input field. - Added detection for Google Colab to not bring up the GUI file/folder window on the platform. Instead it will only use the file/folder path provided in the input field.

View File

@ -108,6 +108,7 @@ def save_configuration(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -216,6 +217,7 @@ def open_configuration(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -306,6 +308,7 @@ def train_model(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -335,12 +338,17 @@ def train_model(
subfolders = [ subfolders = [
f f
for f in os.listdir(train_data_dir) for f in os.listdir(train_data_dir)
if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.') if os.path.isdir(os.path.join(train_data_dir, f))
and not f.startswith('.')
] ]
# Check if subfolders are present. If not let the user know and return # Check if subfolders are present. If not let the user know and return
if not subfolders: if not subfolders:
print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train\...033[0m') print(
'\033[33mNo subfolders were found in',
train_data_dir,
" can't train\...033[0m",
)
return return
total_steps = 0 total_steps = 0
@ -351,18 +359,24 @@ def train_model(
try: try:
repeats = int(folder.split('_')[0]) repeats = int(folder.split('_')[0])
except ValueError: except ValueError:
print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m') print(
'\033[33mSubfolder',
folder,
"does not have a proper repeat value, please correct the name or remove it... can't train...\033[0m",
)
continue continue
# Count the number of images in the folder # Count the number of images in the folder
num_images = len( num_images = len(
[ [
f f
for f in os.listdir(os.path.join(train_data_dir, folder)) for f, lower_f in (
if f.endswith('.jpg') (file, file.lower())
or f.endswith('.jpeg') for file in os.listdir(
or f.endswith('.png') os.path.join(train_data_dir, folder)
or f.endswith('.webp') )
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
] ]
) )
@ -377,7 +391,11 @@ def train_model(
print('\033[33mFolder', folder, ':', steps, 'steps\033[0m') print('\033[33mFolder', folder, ':', steps, 'steps\033[0m')
if total_steps == 0: if total_steps == 0:
print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m') print(
'\033[33mNo images were found in folder',
train_data_dir,
'... please rectify!\033[0m',
)
return return
# Print the result # Print the result
@ -386,7 +404,9 @@ def train_model(
if reg_data_dir == '': if reg_data_dir == '':
reg_factor = 1 reg_factor = 1
else: else:
print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m') print(
'\033[94mRegularisation images are used... Will double the number of steps required...\033[0m'
)
reg_factor = 2 reg_factor = 2
# calculate max_train_steps # calculate max_train_steps
@ -498,6 +518,7 @@ def train_model(
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -705,6 +726,7 @@ def dreambooth_tab(
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -806,6 +828,7 @@ def dreambooth_tab(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
] ]
button_open_config.click( button_open_config.click(

View File

@ -6,6 +6,7 @@ import gc
import math import math
import os import os
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@ -19,10 +20,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
def train(args): def train(args):
@ -64,6 +63,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
@ -187,16 +191,21 @@ def train(args):
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@ -255,13 +264,14 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch + 1
for m in training_models: for m in training_models:
m.train() m.train()
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
@ -302,7 +312,14 @@ def train(args):
else: else:
target = noise target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") if args.min_snr_gamma:
# do not mean over batch dimension for snr weight
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss) accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0: if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@ -396,6 +413,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

View File

@ -104,7 +104,9 @@ def save_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -217,7 +219,9 @@ def open_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -312,7 +316,9 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
if check_if_model_exist(output_name, output_dir, save_model_as): if check_if_model_exist(output_name, output_dir, save_model_as):
return return
@ -368,8 +374,10 @@ def train_model(
image_num = len( image_num = len(
[ [
f f
for f in os.listdir(image_folder) for f, lower_f in (
if f.endswith('.jpg') or f.endswith('.png') or f.endswith('.webp') (file, file.lower()) for file in os.listdir(image_folder)
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
] ]
) )
print(f'image_num = {image_num}') print(f'image_num = {image_num}')
@ -471,6 +479,7 @@ def train_model(
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -688,6 +697,7 @@ def finetune_tab():
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -783,6 +793,7 @@ def finetune_tab():
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

File diff suppressed because it is too large Load Diff

View File

@ -840,6 +840,7 @@ def gradio_advanced_training():
xformers = gr.Checkbox(label='Use xformers', value=True) xformers = gr.Checkbox(label='Use xformers', value=True)
color_aug = gr.Checkbox(label='Color augmentation', value=False) color_aug = gr.Checkbox(label='Color augmentation', value=False)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1)
with gr.Row(): with gr.Row():
bucket_no_upscale = gr.Checkbox( bucket_no_upscale = gr.Checkbox(
label="Don't upscale bucket resolution", value=True label="Don't upscale bucket resolution", value=True
@ -914,6 +915,7 @@ def gradio_advanced_training():
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
) )
@ -949,13 +951,15 @@ def run_cmd_advanced_training(**kwargs):
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
if int(kwargs.get('bucket_reso_steps', 64)) >= 1 if int(kwargs.get('bucket_reso_steps', 64)) >= 1
else '', else '',
f' --min_snr_gamma={int(kwargs.get("min_snr_gamma", 0))}'
if int(kwargs.get('min_snr_gamma', 0)) >= 1
else '',
' --save_state' if kwargs.get('save_state') else '', ' --save_state' if kwargs.get('save_state') else '',
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '', ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
' --color_aug' if kwargs.get('color_aug') else '', ' --color_aug' if kwargs.get('color_aug') else '',
' --flip_aug' if kwargs.get('flip_aug') else '', ' --flip_aug' if kwargs.get('flip_aug') else '',
' --shuffle_caption' if kwargs.get('shuffle_caption') else '', ' --shuffle_caption' if kwargs.get('shuffle_caption') else '',
' --gradient_checkpointing' ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing')
if kwargs.get('gradient_checkpointing')
else '', else '',
' --full_fp16' if kwargs.get('full_fp16') else '', ' --full_fp16' if kwargs.get('full_fp16') else '',
' --xformers' if kwargs.get('xformers') else '', ' --xformers' if kwargs.get('xformers') else '',

View File

@ -4,6 +4,7 @@ from dataclasses import (
dataclass, dataclass,
) )
import functools import functools
import random
from textwrap import dedent, indent from textwrap import dedent, indent
import json import json
from pathlib import Path from pathlib import Path
@ -56,6 +57,8 @@ class BaseSubsetParams:
caption_dropout_rate: float = 0.0 caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0 caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0 caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
@dataclass @dataclass
class DreamBoothSubsetParams(BaseSubsetParams): class DreamBoothSubsetParams(BaseSubsetParams):
@ -137,6 +140,8 @@ class ConfigSanitizer:
"random_crop": bool, "random_crop": bool,
"shuffle_caption": bool, "shuffle_caption": bool,
"keep_tokens": int, "keep_tokens": int,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
} }
# DO means DropOut # DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = { DO_SUBSET_ASCENDABLE_SCHEMA = {
@ -406,6 +411,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug} flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range} face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop} random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ") """), " ")
if is_dreambooth: if is_dreambooth:
@ -422,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
print(info) print(info)
# make buckets first because it determines the length of dataset # make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets): for i, dataset in enumerate(datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.make_buckets() dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets) return DatasetGroup(datasets)
@ -491,7 +501,6 @@ def load_user_config(file: str) -> dict:
return config return config
# for config test # for config test
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -0,0 +1,18 @@
import torch
import argparse
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
loss = loss * snr_weight
return loss
def add_custom_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")

View File

@ -136,7 +136,7 @@ def gradio_extract_lora_tab():
dim = gr.Slider( dim = gr.Slider(
minimum=4, minimum=4,
maximum=1024, maximum=1024,
label='Network Dimension', label='Network Dimension (Rank)',
value=128, value=128,
step=1, step=1,
interactive=True, interactive=True,
@ -144,8 +144,8 @@ def gradio_extract_lora_tab():
conv_dim = gr.Slider( conv_dim = gr.Slider(
minimum=0, minimum=0,
maximum=1024, maximum=1024,
label='Conv Dimension', label='Conv Dimension (Rank)',
value=0, value=128,
step=1, step=1,
interactive=True, interactive=True,
) )

View File

@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
key_count = len(state_dict.keys()) key_count = len(state_dict.keys())
new_ckpt = {'state_dict': state_dict} new_ckpt = {'state_dict': state_dict}
if 'epoch' in checkpoint: # epoch and global_step are sometimes not int
epochs += checkpoint['epoch'] try:
if 'global_step' in checkpoint: if 'epoch' in checkpoint:
steps += checkpoint['global_step'] epochs += checkpoint['epoch']
if 'global_step' in checkpoint:
steps += checkpoint['global_step']
except:
pass
new_ckpt['epoch'] = epochs new_ckpt['epoch'] = epochs
new_ckpt['global_step'] = steps new_ckpt['global_step'] = steps

View File

@ -276,6 +276,8 @@ class BaseSubset:
caption_dropout_rate: float, caption_dropout_rate: float,
caption_dropout_every_n_epochs: int, caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float, caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float, int],
) -> None: ) -> None:
self.image_dir = image_dir self.image_dir = image_dir
self.num_repeats = num_repeats self.num_repeats = num_repeats
@ -289,6 +291,9 @@ class BaseSubset:
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最大になる
self.img_count = 0 self.img_count = 0
@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) )
self.is_reg = is_reg self.is_reg = is_reg
@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) )
self.metadata_file = metadata_file self.metadata_file = metadata_file
@ -404,6 +417,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
self.max_train_steps: int = 0
self.seed: int = 0
# augmentation # augmentation
self.aug_helper = AugHelper() self.aug_helper = AugHelper()
@ -419,9 +436,19 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {} self.replacements = {}
def set_seed(self, seed):
self.seed = seed
def set_current_epoch(self, epoch): def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch self.current_epoch = epoch
self.shuffle_buckets()
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions): def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {}) frequency_for_dir = self.tag_frequency.get(dir_name, {})
@ -452,7 +479,16 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out: if is_drop_out:
caption = "" caption = ""
else: else:
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
def dropout_tags(tokens): def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0: if subset.caption_tag_dropout_rate <= 0:
@ -464,10 +500,10 @@ class BaseDataset(torch.utils.data.Dataset):
return l return l
fixed_tokens = [] fixed_tokens = []
flex_tokens = [t.strip() for t in caption.strip().split(",")] flex_tokens = tokens[:]
if subset.keep_tokens > 0: if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens] fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = flex_tokens[subset.keep_tokens :] flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption: if subset.shuffle_caption:
random.shuffle(flex_tokens) random.shuffle(flex_tokens)
@ -637,6 +673,9 @@ class BaseDataset(torch.utils.data.Dataset):
self._length = len(self.buckets_indices) self._length = len(self.buckets_indices)
def shuffle_buckets(self): def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
random.shuffle(self.buckets_indices) random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle() self.bucket_manager.shuffle()
@ -1043,7 +1082,7 @@ class DreamBoothDataset(BaseDataset):
self.register_image(info, subset) self.register_image(info, subset)
n += info.num_repeats n += info.num_repeats
else: else:
info.num_repeats += 1 info.num_repeats += 1 # rewrite registered info
n += 1 n += 1
if n >= num_train_images: if n >= num_train_images:
break break
@ -1104,6 +1143,8 @@ class FineTuningDataset(BaseDataset):
# path情報を作る # path情報を作る
if os.path.exists(image_key): if os.path.exists(image_key):
abs_path = image_key abs_path = image_key
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else: else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz") npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path): if os.path.exists(npz_path):
@ -1285,6 +1326,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets: for dataset in self.datasets:
dataset.set_current_epoch(epoch) dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self): def disable_token_padding(self):
for dataset in self.datasets: for dataset in self.datasets:
dataset.disable_token_padding() dataset.disable_token_padding()
@ -1292,37 +1341,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def debug_dataset(train_dataset, show_input_ids=False): def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します") print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
train_dataset.set_current_epoch(1) epoch = 1
k = 0 while True:
indices = list(range(len(train_dataset))) print(f"epoch: {epoch}")
random.shuffle(indices)
for i, idx in enumerate(indices): steps = (epoch - 1) * len(train_dataset) + 1
example = train_dataset[idx] indices = list(range(len(train_dataset)))
if example["latents"] is not None: random.shuffle(indices)
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate( k = 0
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) for i, idx in enumerate(indices):
): train_dataset.set_current_epoch(epoch)
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') train_dataset.set_current_step(steps)
if show_input_ids: print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
print(f"input ids: {iid}")
if example["images"] is not None: example = train_dataset[idx]
im = example["images"][j] if example["latents"] is not None:
print(f"image size: {im.size()}") print(f"sample has latents from npz file: {example['latents'].size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) for j, (ik, cap, lw, iid) in enumerate(
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
im = im[:, :, ::-1] # RGB -> BGR (OpenCV) ):
if os.name == "nt": # only windows print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
cv2.imshow("img", im) if show_input_ids:
k = cv2.waitKey() print(f"input ids: {iid}")
cv2.destroyAllWindows() if example["images"] is not None:
if k == 27: im = example["images"][j]
break print(f"image size: {im.size()}")
if k == 27 or (example["images"] is None and i >= 8): im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27 or k == ord("s") or k == ord("e"):
break
steps += 1
if k == ord("e"):
break
if k == 27 or (example["images"] is None and i >= 8):
k = 27
break
if k == 27:
break break
epoch += 1
def glob_images(directory, base="*"): def glob_images(directory, base="*"):
img_paths = [] img_paths = []
@ -1331,8 +1398,8 @@ def glob_images(directory, base="*"):
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else: else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
# img_paths = list(set(img_paths)) # 重複を排除 img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort() img_paths.sort()
return img_paths return img_paths
@ -1344,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive):
else: else:
for ext in IMAGE_EXTENSIONS: for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob("*" + ext)) image_paths += list(dir_path.glob("*" + ext))
# image_paths = list(set(image_paths)) # 重複を排除 image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort() image_paths.sort()
return image_paths return image_paths
@ -2038,6 +2105,20 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
) )
parser.add_argument(
"--token_warmup_min",
type=int,
default=1,
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
if support_caption_dropout: if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない # Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@ -2972,3 +3053,24 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion # endregion
# collate_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
def __call__(self, examples):
worker_info = torch.utils.data.get_worker_info()
# worker_info is None in the main process
if worker_info is not None:
dataset = worker_info.dataset
else:
dataset = self.dataset
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]

View File

@ -123,7 +123,9 @@ def save_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -240,7 +242,9 @@ def open_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -348,7 +352,9 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
print_only_bool = True if print_only.get('label') == 'True' else False print_only_bool = True if print_only.get('label') == 'True' else False
@ -419,11 +425,13 @@ def train_model(
num_images = len( num_images = len(
[ [
f f
for f in os.listdir(os.path.join(train_data_dir, folder)) for f, lower_f in (
if f.endswith('.jpg') (file, file.lower())
or f.endswith('.jpeg') for file in os.listdir(
or f.endswith('.png') os.path.join(train_data_dir, folder)
or f.endswith('.webp') )
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
] ]
) )
@ -591,6 +599,7 @@ def train_model(
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -649,10 +658,12 @@ def lora_tab(
v_parameterization, v_parameterization,
save_model_as, save_model_as,
model_list, model_list,
) = gradio_source_model(save_model_as_choices = [ ) = gradio_source_model(
'ckpt', save_model_as_choices=[
'safetensors', 'ckpt',
]) 'safetensors',
]
)
with gr.Tab('Folders'): with gr.Tab('Folders'):
with gr.Row(): with gr.Row():
@ -897,6 +908,7 @@ def lora_tab(
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -1015,6 +1027,7 @@ def lora_tab(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
] ]
button_open_config.click( button_open_config.click(
@ -1104,7 +1117,7 @@ def UI(**kwargs):
if kwargs.get('inbrowser', False): if kwargs.get('inbrowser', False):
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
if kwargs.get('listen', True): if kwargs.get('listen', True):
launch_kwargs['server_name'] = "0.0.0.0" launch_kwargs['server_name'] = '0.0.0.0'
print(launch_kwargs) print(launch_kwargs)
interface.launch(**launch_kwargs) interface.launch(**launch_kwargs)
@ -1128,7 +1141,9 @@ if __name__ == '__main__':
'--inbrowser', action='store_true', help='Open in browser' '--inbrowser', action='store_true', help='Open in browser'
) )
parser.add_argument( parser.add_argument(
'--listen', action='store_true', help='Launch gradio with server name 0.0.0.0, allowing LAN access' '--listen',
action='store_true',
help='Launch gradio with server name 0.0.0.0, allowing LAN access',
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -11,6 +11,8 @@ import numpy as np
MIN_SV = 1e-6 MIN_SV = 1e-6
# Model save and load functions
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name): if model_util.is_safetensors(file_name):
sd = load_file(file_name) sd = load_file(file_name)
@ -39,12 +41,13 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
torch.save(model, file_name) torch.save(model, file_name)
# Indexing functions
def index_sv_cumulative(S, target): def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S)) original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1 index = int(torch.searchsorted(cumulative_sums, target)) + 1
if index >= len(S): index = max(1, min(index, len(S)-1))
index = len(S) - 1
return index return index
@ -54,8 +57,16 @@ def index_sv_fro(S, target):
s_fro_sq = float(torch.sum(S_squared)) s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
if index >= len(S): index = max(1, min(index, len(S)-1))
index = len(S) - 1
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv/target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S)-1))
return index return index
@ -125,26 +136,24 @@ def merge_linear(lora_down, lora_up, device):
return weight return weight
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {} param_dict = {}
if dynamic_method=="sv_ratio": if dynamic_method=="sv_ratio":
# Calculate new dim and alpha based off ratio # Calculate new dim and alpha based off ratio
max_sv = S[0] new_rank = index_sv_ratio(S, dynamic_param) + 1
min_sv = max_sv/dynamic_param
new_rank = max(torch.sum(S > min_sv).item(),1)
new_alpha = float(scale*new_rank) new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_cumulative": elif dynamic_method=="sv_cumulative":
# Calculate new dim and alpha based off cumulative sum # Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_rank = max(new_rank, 1)
new_alpha = float(scale*new_rank) new_alpha = float(scale*new_rank)
elif dynamic_method=="sv_fro": elif dynamic_method=="sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares # Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) new_rank = index_sv_fro(S, dynamic_param) + 1
new_rank = min(max(new_rank, 1), len(S)-1)
new_alpha = float(scale*new_rank) new_alpha = float(scale*new_rank)
else: else:
new_rank = rank new_rank = rank
@ -172,7 +181,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict["new_alpha"] = new_alpha param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["fro_retained"] = fro_percent param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank] param_dict["max_ratio"] = S[0]/S[new_rank - 1]
return param_dict return param_dict

View File

@ -112,7 +112,9 @@ def save_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -225,7 +227,9 @@ def open_configuration(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -320,7 +324,9 @@ def train_model(
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
additional_parameters,vae_batch_size, additional_parameters,
vae_batch_size,
min_snr_gamma,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -375,11 +381,13 @@ def train_model(
num_images = len( num_images = len(
[ [
f f
for f in os.listdir(os.path.join(train_data_dir, folder)) for f, lower_f in (
if f.endswith('.jpg') (file, file.lower())
or f.endswith('.jpeg') for file in os.listdir(
or f.endswith('.png') os.path.join(train_data_dir, folder)
or f.endswith('.webp') )
)
if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp'))
] ]
) )
@ -512,6 +520,7 @@ def train_model(
noise_offset=noise_offset, noise_offset=noise_offset,
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma,
) )
run_cmd += f' --token_string="{token_string}"' run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"' run_cmd += f' --init_word="{init_word}"'
@ -570,10 +579,12 @@ def ti_tab(
v_parameterization, v_parameterization,
save_model_as, save_model_as,
model_list, model_list,
) = gradio_source_model(save_model_as_choices = [ ) = gradio_source_model(
'ckpt', save_model_as_choices=[
'safetensors', 'ckpt',
]) 'safetensors',
]
)
with gr.Tab('Folders'): with gr.Tab('Folders'):
with gr.Row(): with gr.Row():
@ -775,6 +786,7 @@ def ti_tab(
noise_offset, noise_offset,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -882,6 +894,7 @@ def ti_tab(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,
] ]
button_open_config.click( button_open_config.click(

426
train_db - Copy.py Normal file
View File

@ -0,0 +1,426 @@
# DreamBooth training
# XXX dropped option: fine_tune
import gc
import time
import argparse
import itertools
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args)
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
collater = train_util.collater_class(current_epoch,current_step)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
if args.gradient_accumulation_steps > 1:
print(
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
)
print(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
# verify load/save model formats
if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None
else:
src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors
else:
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加
text_encoder.requires_grad_(train_text_encoder)
if not train_text_encoder:
print("Text Encoder is not trained.")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
if train_text_encoder:
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch+1
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()
# train==True is required to enable gradient_checkpointing
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
with accelerator.accumulate(unet):
with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if train_text_encoder:
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
params_to_clip = unet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end(
args,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--no_token_padding",
action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作",
)
parser.add_argument(
"--stop_text_encoder_training",
type=int,
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@ -8,6 +8,7 @@ import itertools
import math import math
import os import os
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@ -21,10 +22,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
def train(args): def train(args):
@ -59,6 +58,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.no_token_padding: if args.no_token_padding:
train_dataset_group.disable_token_padding() train_dataset_group.disable_token_padding()
@ -152,16 +156,21 @@ def train(args):
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None: if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
@ -229,7 +238,7 @@ def train(args):
loss_total = 0.0 loss_total = 0.0
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch + 1
# 指定したステップ数までText Encoderを学習するepoch最初の状態 # 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train() unet.train()
@ -238,6 +247,7 @@ def train(args):
text_encoder.train() text_encoder.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める # 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training: if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}") print(f"stop text encoder training at step {global_step}")
@ -291,6 +301,9 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss) accelerator.backward(loss)
@ -390,6 +403,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument( parser.add_argument(
"--no_token_padding", "--no_token_padding",

710
train_network - Copy.py Normal file
View File

@ -0,0 +1,710 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse
import gc
import math
import os
import random
import time
import json
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
import library.train_util as train_util
from library.train_util import (
DreamBoothDataset,
)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if args.network_train_unet_only:
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
else:
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None
if args.seed is not None:
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
collater = train_util.collater_class(current_epoch,current_step)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# work on low-ram device
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# prepare network
import sys
sys.path.append(os.path.dirname(__file__))
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
if args.network_weights is not None:
print("load network weights from:", args.network_weights)
network.load_weights(args.network_weights)
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_unet and train_text_encoder:
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
elif train_unet:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
elif train_text_encoder:
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
text_encoder.train()
# set top parameter requires_grad = True for gradient checkpointing works
if type(text_encoder) == DDP:
text_encoder.module.text_model.embeddings.requires_grad_(True)
else:
text_encoder.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
text_encoder.eval()
# support DistributedDataParallel
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
if is_main_process:
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
# TODO refactor metadata creation and move to util
metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset_group.num_train_images,
"ss_num_reg_images": train_dataset_group.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
"ss_gradient_checkpointing": args.gradient_checkpointing,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not use this value
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
"ss_clip_skip": args.clip_skip,
"ss_max_token_length": args.max_token_length,
"ss_cache_latents": bool(args.cache_latents),
"ss_seed": args.seed,
"ss_lowram": args.lowram,
"ss_noise_offset": args.noise_offset,
"ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
"ss_max_grad_norm": args.max_grad_norm,
"ss_caption_dropout_rate": args.caption_dropout_rate,
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
}
if use_user_config:
# save metadata of multiple datasets
# NOTE: pack "ss_datasets" value as json one time
# or should also pack nested collections as json?
datasets_metadata = []
tag_frequency = {} # merge tag frequency for metadata editor
dataset_dirs_info = {} # merge subset dirs for metadata editor
for dataset in train_dataset_group.datasets:
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
dataset_metadata = {
"is_dreambooth": is_dreambooth_dataset,
"batch_size_per_device": dataset.batch_size,
"num_train_images": dataset.num_train_images, # includes repeating
"num_reg_images": dataset.num_reg_images,
"resolution": (dataset.width, dataset.height),
"enable_bucket": bool(dataset.enable_bucket),
"min_bucket_reso": dataset.min_bucket_reso,
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
}
subsets_metadata = []
for subset in dataset.subsets:
subset_metadata = {
"img_count": subset.img_count,
"num_repeats": subset.num_repeats,
"color_aug": bool(subset.color_aug),
"flip_aug": bool(subset.flip_aug),
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
}
image_dir_or_metadata_file = None
if subset.image_dir:
image_dir = os.path.basename(subset.image_dir)
subset_metadata["image_dir"] = image_dir
image_dir_or_metadata_file = image_dir
if is_dreambooth_dataset:
subset_metadata["class_tokens"] = subset.class_tokens
subset_metadata["is_reg"] = subset.is_reg
if subset.is_reg:
image_dir_or_metadata_file = None # not merging reg dataset
else:
metadata_file = os.path.basename(subset.metadata_file)
subset_metadata["metadata_file"] = metadata_file
image_dir_or_metadata_file = metadata_file # may overwrite
subsets_metadata.append(subset_metadata)
# merge dataset dir: not reg subset only
# TODO update additional-network extension to show detailed dataset config from metadata
if image_dir_or_metadata_file is not None:
# datasets may have a certain dir multiple times
v = image_dir_or_metadata_file
i = 2
while v in dataset_dirs_info:
v = image_dir_or_metadata_file + f" ({i})"
i += 1
image_dir_or_metadata_file = v
dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
dataset_metadata["subsets"] = subsets_metadata
datasets_metadata.append(dataset_metadata)
# merge tag frequency:
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
# なので、ここで複数datasetの回数を合算してもあまり意味はない
if ds_dir_name in tag_frequency:
continue
tag_frequency[ds_dir_name] = ds_freq_for_dir
metadata["ss_datasets"] = json.dumps(datasets_metadata)
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
else:
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
assert (
len(train_dataset_group.datasets) == 1
), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
dataset = train_dataset_group.datasets[0]
dataset_dirs_info = {}
reg_dataset_dirs_info = {}
if use_dreambooth_method:
for subset in dataset.subsets:
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
else:
for subset in dataset.subsets:
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
"n_repeats": subset.num_repeats,
"img_count": subset.img_count,
}
metadata.update(
{
"ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size,
"ss_resolution": args.resolution,
"ss_color_aug": bool(args.color_aug),
"ss_flip_aug": bool(args.flip_aug),
"ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_enable_bucket": bool(dataset.enable_bucket),
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
"ss_min_bucket_reso": dataset.min_bucket_reso,
"ss_max_bucket_reso": dataset.max_bucket_reso,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
"ss_bucket_info": json.dumps(dataset.bucket_info),
}
)
# add extra args
if args.network_args:
metadata["ss_network_args"] = json.dumps(net_kwargs)
# for key, value in net_kwargs.items():
# metadata["ss_arg_" + key] = value
# model name and hash
if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
if args.vae is not None:
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
metadata = {k: str(v) for k, v in metadata.items()}
# make minimum metadata for filtering
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
minimum_metadata = {}
for key in minimum_keys:
if key in metadata:
minimum_metadata[key] = metadata[key]
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
loss_list = []
loss_total = 0.0
del train_dataset_group
for epoch in range(num_train_epochs):
if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch+1
metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
with torch.set_grad_enabled(train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
if is_main_process:
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch
metadata["ss_epoch"] = str(num_train_epochs)
metadata["ss_training_finished_at"] = str(time.time())
if is_main_process:
network = unwrap_model(network)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール")
parser.add_argument(
"--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)"
)
parser.add_argument(
"--network_alpha",
type=float,
default=1,
help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定",
)
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument(
"--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
)
parser.add_argument(
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@ -8,6 +8,7 @@ import random
import time import time
import json import json
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@ -23,10 +24,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
def collate_fn(examples):
return examples[0]
# TODO 他のスクリプトと共通化する # TODO 他のスクリプトと共通化する
@ -100,6 +99,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
@ -185,11 +189,12 @@ def train(args):
# dataloaderを準備する # dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる # DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
@ -200,6 +205,9 @@ def train(args):
if is_main_process: if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@ -488,22 +496,23 @@ def train(args):
noise_scheduler = DDPMScheduler( noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("network_train") accelerator.init_trackers("network_train")
loss_list = [] loss_list = []
loss_total = 0.0 loss_total = 0.0
del train_dataset_group
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
if is_main_process: if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch+1
metadata["ss_epoch"] = str(epoch + 1) metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet) network.on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network): with accelerator.accumulate(network):
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
@ -528,7 +537,6 @@ def train(args):
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long() timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@ -549,6 +557,9 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss) accelerator.backward(loss)
@ -652,6 +663,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument( parser.add_argument(

View File

@ -0,0 +1,589 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# Convert the init_word to token_id
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == args.num_vectors_per_token
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_ids is not None:
for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
collater = train_util.collater_class(current_epoch,current_step)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print("use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
for tmpl in templates:
captions.append(tmpl.format(replace_to))
train_dataset_group.add_replacement("", captions)
if args.num_vectors_per_token > 1:
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
else:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
else:
unet.eval()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch+1
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
# end of epoch
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
print("model saved.")
def save_weights(file, updated_embs, save_dtype):
state_dict = {"emb_params": updated_embs}
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file) # can be loaded in Web UI
def load_weights(file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
# compatible to Web UI's file format
data = torch.load(file, map_location="cpu")
if type(data) != dict:
raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
if "string_to_param" in data: # textual inversion embeddings
data = data["string_to_param"]
if hasattr(data, "_parameters"): # support old PyTorch?
data = getattr(data, "_parameters")
emb = next(iter(data.values()))
if type(emb) != torch.Tensor:
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
if len(emb.size()) == 1:
emb = emb.unsqueeze(0)
return emb
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="pt",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
parser.add_argument(
"--token_string",
type=str,
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",
action="store_true",
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
)
parser.add_argument(
"--use_style_template",
action="store_true",
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@ -4,6 +4,7 @@ import gc
import math import math
import os import os
import toml import toml
from multiprocessing import Value
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@ -17,6 +18,8 @@ from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
imagenet_templates_small = [ imagenet_templates_small = [
"a photo of a {}", "a photo of a {}",
@ -71,10 +74,6 @@ imagenet_style_templates_small = [
] ]
def collate_fn(examples):
return examples[0]
def train(args): def train(args):
if args.output_name is None: if args.output_name is None:
args.output_name = args.token_string args.output_name = args.token_string
@ -185,6 +184,11 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template: if use_template:
print("use template for training captions. is object: {args.use_object_template}") print("use template for training captions. is object: {args.use_object_template}")
@ -250,7 +254,7 @@ def train(args):
train_dataset_group, train_dataset_group,
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collater,
num_workers=n_workers, num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
@ -260,6 +264,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@ -331,12 +338,14 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1) current_epoch.value = epoch+1
text_encoder.train() text_encoder.train()
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder): with accelerator.accumulate(text_encoder):
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
@ -378,6 +387,9 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
@ -534,6 +546,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument( parser.add_argument(
"--save_model_as", "--save_model_as",