Update to latest sd-script code

bmaltais 2023-03-20 08:47:00 -04:00
parent 09ad7961e3
commit ccae80186a
23 changed files with 5678 additions and 3640 deletions


@@ -41,6 +41,9 @@ If you run on Linux and would like to use the GUI, there is now a port of it as
## Installation

### Runpod
Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379

### Ubuntu
In the terminal, run
@@ -189,6 +192,19 @@ This will store a backup file with your current locally installed pip packages
## Change History
* 2023/03/19 (v21.3.0)
    - Add a function to load the training config from a `.toml` file to each training script. Thanks to Linaqruf for this great contribution!
        - Specify the `.toml` file with `--config_file`. The `.toml` file has `key=value` entries, and the keys are the same as the command line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details, and the example configuration below.
        - All sub-sections are combined into a single dictionary (the section names are ignored).
        - Omitted arguments take the default values of the command line arguments.
        - Command line arguments override the arguments in the `.toml` file.
        - With the `--output_config` option, you can write the current command line options to the `.toml` file specified with `--config_file`. Please use it as a template.
    - Add `--lr_scheduler_type` and `--lr_scheduler_args` arguments for custom LR schedulers to each training script. Thanks to Isotr0py! [#271](https://github.com/kohya-ss/sd-scripts/pull/271)
        - These work the same way as the optimizer arguments.
    - Add prompt weighting, with no length limit, to sample image generation. Thanks to mio2333! [#288](https://github.com/kohya-ss/sd-scripts/pull/288)
        - `( )`, `(xxxx:1.2)` and `[ ]` can be used (see the example prompt below).
    - Fix an exception when training a model in diffusers format with `train_network.py`. Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290)
    - Add a warning if you are about to overwrite an existing model: https://github.com/bmaltais/kohya_ss/issues/404
* 2023/03/19 (v21.2.5):
    - Fix basic captioning logic
    - Add the possibility to not train the TE in Dreambooth by setting `Step text encoder training` to -1.
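As an illustration of the new `--config_file` option, a minimal `.toml` might look like the sketch below. The keys simply mirror the existing command line options; the model path, folders, scheduler name and scheduler arguments shown here are illustrative placeholders, not recommendations:

```toml
# config.toml -- keys are the ordinary command line option names
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"  # placeholder model
train_data_dir = "./train_data"                                   # placeholder folder
output_dir = "./output"
output_name = "my_model"
max_train_steps = 1600
learning_rate = 1e-6
optimizer_type = "AdamW8bit"

# the new custom scheduler options take key=value strings, like the optimizer args
lr_scheduler_type = "CosineAnnealingLR"                           # assumed example scheduler
lr_scheduler_args = [ "T_max=1600", "eta_min=1e-7" ]
```

Running a training script with `--config_file config.toml --learning_rate 2e-6` would then load these values, with the command line `--learning_rate` taking precedence and any omitted options keeping their normal defaults.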
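Similarly, a hypothetical line for the prompt file used by sample image generation, showing the newly supported weighting syntax; the trailing `--n`/`--w`/`--h`/`--s` switches are assumed per-prompt sampling options and are included only for context:

```
masterpiece, (best quality:1.2), 1girl, [simple background] --n lowres, (worst quality:1.4) --w 512 --h 512 --s 28
```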


@@ -26,6 +26,7 @@ from library.common_gui import (
    gradio_source_model,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
)
from library.tensorboard_gui import (
    gradio_tensorboard,
@@ -104,7 +105,8 @@ def save_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -210,15 +212,16 @@ def open_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())

    ask_for_file = True if ask_for_file.get('label') == 'True' else False

    original_file_path = file_path

    if ask_for_file:
        file_path = get_file_path(file_path)
@@ -298,7 +301,8 @@ def train_model(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    if pretrained_model_name_or_path == '':
        msgbox('Source model information is missing')
@@ -321,6 +325,9 @@ def train_model(
        msgbox('Output folder path is missing')
        return

    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # Get a list of all subfolders in train_data_dir
    subfolders = [
        f
@@ -787,7 +794,7 @@ def dreambooth_tab(
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_load_config.click(
        open_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,


@@ -5,6 +5,7 @@ import argparse
import gc
import math
import os
import toml
from tqdm import tqdm
import torch
@@ -15,351 +16,391 @@ from diffusers import DDPMScheduler
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)

Almost all of this hunk reformats the body of `train()` to black style (one argument per line, wrapped long lines, double quotes) without changing its behaviour: the dataset/blueprint setup, latent caching, xformers flags, optimizer and dataloader preparation, the training loop, sample image generation, and the epoch and end-of-training checkpoint saving are unchanged apart from layout. The functional changes visible in the hunk are the simplified scheduler call

    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

which replaces the previous call that passed `num_warmup_steps`, `num_training_steps`, `num_cycles` and `power` explicitly, and the new `.toml` config hook in the entry point:

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, False, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)

    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)


@@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os


def main(args):
    assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@@ -29,6 +29,9 @@ def main(args):
        caption_path = image_path.with_suffix(args.caption_extension)
        caption = caption_path.read_text(encoding='utf-8').strip()

        if not os.path.exists(caption_path):
            caption_path = os.path.join(image_path, args.caption_extension)

        image_key = str(image_path) if args.full_path else image_path.stem
        if image_key not in metadata:
            metadata[image_key] = {}


@@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os


def main(args):
    assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@@ -29,6 +29,9 @@ def main(args):
        tags_path = image_path.with_suffix(args.caption_extension)
        tags = tags_path.read_text(encoding='utf-8').strip()

        if not os.path.exists(tags_path):
            tags_path = os.path.join(image_path, args.caption_extension)

        image_key = str(image_path) if args.full_path else image_path.stem
        if image_key not in metadata:
            metadata[image_key] = {}


@@ -125,7 +125,7 @@ def main(args):
            tag_text = ""
            for i, p in enumerate(prob[4:]):   # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
                if p >= args.thresh and i < len(tags):
                    tag_text += ", " + tags[i]

            if len(tag_text) > 0:
                tag_text = tag_text[2:]   # 最初の ", " を消す
@@ -190,7 +190,6 @@ if __name__ == '__main__':
                        help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
    parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
    parser.add_argument("--debug", action="store_true", help="debug mode")

    args = parser.parse_args()

(These two hunks remove the `--replace_underscores` option and the corresponding `tags[i].replace("_", " ")` handling in the tag loop.)


@@ -20,6 +20,7 @@ from library.common_gui import (
    run_cmd_training,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
)
from library.tensorboard_gui import (
    gradio_tensorboard,
@@ -102,7 +103,8 @@ def save_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@@ -214,15 +216,16 @@ def open_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())

    ask_for_file = True if ask_for_file.get('label') == 'True' else False

    original_file_path = file_path

    if ask_for_file:
        file_path = get_file_path(file_path)
@@ -308,8 +311,12 @@ def train_model(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # create caption json file
    if generate_caption_database:
        if not os.path.exists(train_dir):
@@ -677,7 +684,8 @@ def finetune_tab():
            bucket_reso_steps,
            caption_dropout_every_n_epochs,
            caption_dropout_rate,
            noise_offset,
            additional_parameters,
        ) = gradio_advanced_training()
        color_aug.change(
            color_aug_changed,
@@ -770,7 +778,8 @@ def finetune_tab():
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
    ]

    button_run.click(train_model, inputs=settings_list)
@@ -781,7 +790,7 @@ def finetune_tab():
        outputs=[config_file_name] + settings_list,
        show_progress=False,
    )

    button_load_config.click(
        open_configuration,
        inputs=[dummy_db_false, config_file_name] + settings_list,


@@ -122,7 +122,7 @@ def gradio_basic_caption_gui_tab():
                label='Replacement text',
                placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing',
                interactive=True,
            )
            caption_button = gr.Button('Caption images')
            caption_button.click(
                caption_images,


@@ -1,7 +1,7 @@
from tkinter import filedialog, Tk
import os
import gradio as gr
import easygui
import shutil

folder_symbol = '\U0001f4c2'   # 📂
@@ -31,6 +31,34 @@ V1_MODELS = [
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS


def check_if_model_exist(output_name, output_dir, save_model_as):
    if save_model_as in ['diffusers', 'diffusers_safetendors']:
        ckpt_folder = os.path.join(output_dir, output_name)
        if os.path.isdir(ckpt_folder):
            msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print(
                    'Aborting training due to existing model with same name...'
                )
                return True
    elif save_model_as in ['ckpt', 'safetensors']:
        ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
        if os.path.isfile(ckpt_file):
            msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print(
                    'Aborting training due to existing model with same name...'
                )
                return True
    else:
        print(
            'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...'
        )
        return False

    return False


def update_my_data(my_data):
    # Update optimizer based on use_8bit_adam flag
    use_8bit_adam = my_data.get('use_8bit_adam', False)
@@ -38,11 +66,16 @@ def update_my_data(my_data):
        my_data['optimizer'] = 'AdamW8bit'
    elif 'optimizer' not in my_data:
        my_data['optimizer'] = 'AdamW'

    # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
    model_list = my_data.get('model_list', [])
    pretrained_model_name_or_path = my_data.get(
        'pretrained_model_name_or_path', ''
    )
    if (
        not model_list
        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
    ):
        my_data['model_list'] = 'custom'

    # Convert epoch and save_every_n_epochs values to int if they are strings
@@ -78,7 +111,7 @@
    # # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
    # if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
    #     my_data['model_list'] = 'custom'

    # # Fix old config files that contain epoch as str instead of int
    # for key in ['epoch', 'save_every_n_epochs']:
    #     value = my_data.get(key, -1)
@@ -87,10 +120,10 @@
    #         my_data[key] = int(value)
    #     else:
    #         my_data[key] = -1

    # if my_data.get('LoRA_type', 'Standard') == 'LoCon':
    #     my_data['LoRA_type'] = 'LyCORIS/LoCon'

    # return my_data
@@ -265,11 +298,11 @@ def get_saveasfilename_path(

def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption',
) -> None:
    """
    Add prefix and/or postfix to the content of caption files within a folder.
    If no caption files are found, create one with the requested prefix and/or postfix.
@@ -285,7 +318,9 @@ def add_pre_postfix(
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
@@ -303,7 +338,10 @@ def add_pre_postfix(
            prefix_separator = ' ' if prefix else ''
            postfix_separator = ' ' if postfix else ''

            f.write(
                f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
            )

# def add_pre_postfix(
#     folder='', prefix='', postfix='', caption_file_ext='.caption'
@@ -335,11 +373,11 @@ def add_pre_postfix(

def has_ext_files(folder_path: str, file_extension: str) -> bool:
    """
    Check if there are any files with the specified extension in the given folder.

    Args:
        folder_path (str): Path to the folder containing files.
        file_extension (str): Extension of the files to look for.

    Returns:
        bool: True if files with the specified extension are found, False otherwise.
    """
@@ -348,15 +386,16 @@ def has_ext_files(folder_path: str, file_extension: str) -> bool:
            return True
    return False


def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
) -> None:
    """
    Find and replace text in caption files within a folder.

    Args:
        folder_path (str, optional): Path to the folder containing caption files.
        caption_file_ext (str, optional): Extension of the caption files.
@@ -364,7 +403,7 @@ def find_replace(
        replace_text (str, optional): Text to replace the search text with.
    """
    print('Running caption find/replace')

    if not has_ext_files(folder_path, caption_file_ext):
        msgbox(
            f'No files with extension {caption_file_ext} were found in {folder_path}...'
@@ -374,10 +413,14 @@
    if search_text == '':
        return

    caption_files = [
        f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
    ]

    for caption_file in caption_files:
        with open(
            os.path.join(folder_path, caption_file), 'r', errors='ignore'
        ) as f:
            content = f.read()

            content = content.replace(search_text, replace_text)
@@ -385,6 +428,7 @@ def find_replace(
        with open(os.path.join(folder_path, caption_file), 'w') as f:
            f.write(content)


# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
#     print('Running caption find/replace')
#     if not has_ext_files(folder, caption_file_ext):
@@ -477,17 +521,15 @@ def set_pretrained_model_name_or_path_input(
    if (
        str(pretrained_model_name_or_path) in V1_MODELS
        or str(pretrained_model_name_or_path) in V2_BASE_MODELS
        or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS
    ):
        pretrained_model_name_or_path = ''
        v2 = False
        v_parameterization = False
    return model_list, pretrained_model_name_or_path, v2, v_parameterization


def set_v2_checkbox(model_list, v2, v_parameterization):
    # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
    if str(model_list) in V2_BASE_MODELS:
        v2 = True
@@ -504,6 +546,7 @@ def set_v2_checkbox(
    return v2, v_parameterization


def set_model_list(
    model_list,
    pretrained_model_name_or_path,
@@ -515,7 +558,7 @@ def set_model_list(
        model_list = 'custom'
    else:
        model_list = pretrained_model_name_or_path

    return model_list, v2, v_parameterization
@@ -538,7 +581,11 @@ def gradio_config():
            interactive=True,
        )
        button_load_config = gr.Button('Load 💾', elem_id='open_folder')
        config_file_name.change(
            remove_doublequote,
            inputs=[config_file_name],
            outputs=[config_file_name],
        )
    return (
        button_open_config,
        button_save_config,
@@ -614,8 +661,18 @@ def gradio_source_model():
            v_parameterization = gr.Checkbox(
                label='v_parameterization', value=False
            )
            v2.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
            v_parameterization.change(
                set_v2_checkbox,
                inputs=[model_list, v2, v_parameterization],
                outputs=[v2, v_parameterization],
                show_progress=False,
            )
            model_list.change(
                set_pretrained_model_name_or_path_input,
                inputs=[
@@ -671,7 +728,9 @@ def gradio_training(
                step=1,
            )
            epoch = gr.Number(label='Epoch', value=1, precision=0)
            save_every_n_epochs = gr.Number(
                label='Save every N epochs', value=1, precision=0
            )
            caption_extension = gr.Textbox(
                label='Caption Extension',
                placeholder='(Optional) Extension for caption files. default: .caption',
@@ -788,7 +847,7 @@ def run_cmd_training(**kwargs):
        if kwargs.get('save_precision')
        else '',
        f' --seed="{kwargs.get("seed", "")}"'
        if kwargs.get('seed') != ''
        else '',
        f' --caption_extension="{kwargs.get("caption_extension", "")}"'
        if kwargs.get('caption_extension')
@@ -807,7 +866,7 @@

def gradio_advanced_training():
    with gr.Row():
        additional_parameters = gr.Textbox(
            label='Additional parameters',
            placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"',
        )
    with gr.Row():
@@ -964,7 +1023,7 @@ def run_cmd_advanced_training(**kwargs):
        f' --noise_offset={float(kwargs.get("noise_offset", 0))}'
        if not kwargs.get('noise_offset', '') == ''
        else '',
        f' {kwargs.get("additional_parameters", "")}',
    ]
    run_cmd = ''.join(options)
    return run_cmd


@@ -153,6 +153,14 @@ def gradio_extract_lora_tab():

        extract_button.click(
            extract_lora,
            inputs=[
                model_tuned,
                model_org,
                save_to,
                save_precision,
                dim,
                v2,
                conv_dim,
            ],
            show_progress=False,
        )


@ -16,12 +16,23 @@ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
def extract_lycoris_locon( def extract_lycoris_locon(
db_model, base_model, output_name, device, db_model,
is_v2, mode, linear_dim, conv_dim, base_model,
linear_threshold, conv_threshold, output_name,
linear_ratio, conv_ratio, device,
linear_quantile, conv_quantile, is_v2,
use_sparse_bias, sparsity, disable_cp mode,
linear_dim,
conv_dim,
linear_threshold,
conv_threshold,
linear_ratio,
conv_ratio,
linear_quantile,
conv_quantile,
use_sparse_bias,
sparsity,
disable_cp,
): ):
# Check for caption_text_input # Check for caption_text_input
if db_model == '': if db_model == '':
@ -41,9 +52,7 @@ def extract_lycoris_locon(
msgbox('The provided base model is not a file') msgbox('The provided base model is not a file')
return return
run_cmd = ( run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
)
if is_v2: if is_v2:
run_cmd += f' --is_v2' run_cmd += f' --is_v2'
run_cmd += f' --device {device}' run_cmd += f' --device {device}'
@ -89,10 +98,11 @@ def extract_lycoris_locon(
# if mode == 'threshold': # if mode == 'threshold':
# return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True) # return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True)
def update_mode(mode): def update_mode(mode):
# Create a list of possible mode values # Create a list of possible mode values
modes = ['fixed', 'threshold', 'ratio', 'quantile'] modes = ['fixed', 'threshold', 'ratio', 'quantile']
# Initialize an empty list to store visibility updates # Initialize an empty list to store visibility updates
updates = [] updates = []
@ -104,12 +114,15 @@ def update_mode(mode):
# Return the visibility updates as a tuple # Return the visibility updates as a tuple
return tuple(updates) return tuple(updates)
def gradio_extract_lycoris_locon_tab(): def gradio_extract_lycoris_locon_tab():
with gr.Tab('Extract LyCORIS LoCON'): with gr.Tab('Extract LyCORIS LoCON'):
gr.Markdown( gr.Markdown(
'This utility can extract a LyCORIS LoCon network from a finetuned model.' 'This utility can extract a LyCORIS LoCon network from a finetuned model.'
) )
lora_ext = gr.Textbox(value='*.safetensors', visible=False) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) lora_ext = gr.Textbox(
value='*.safetensors', visible=False
) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
model_ext_name = gr.Textbox(value='Model types', visible=False) model_ext_name = gr.Textbox(value='Model types', visible=False)
@ -161,14 +174,17 @@ def gradio_extract_lycoris_locon_tab():
) )
device = gr.Dropdown( device = gr.Dropdown(
label='Device', label='Device',
choices=['cpu', 'cuda',], choices=[
'cpu',
'cuda',
],
value='cuda', value='cuda',
interactive=True, interactive=True,
) )
is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True)
mode = gr.Dropdown( mode = gr.Dropdown(
label='Mode', label='Mode',
choices=['fixed', 'threshold','ratio','quantile'], choices=['fixed', 'threshold', 'ratio', 'quantile'],
value='fixed', value='fixed',
interactive=True, interactive=True,
) )
@ -241,7 +257,9 @@ def gradio_extract_lycoris_locon_tab():
interactive=True, interactive=True,
) )
with gr.Row(): with gr.Row():
use_sparse_bias = gr.Checkbox(label='Use sparse biais', value=False, interactive=True) use_sparse_bias = gr.Checkbox(
label='Use sparse biais', value=False, interactive=True
)
sparsity = gr.Slider( sparsity = gr.Slider(
minimum=0, minimum=0,
maximum=1, maximum=1,
@ -250,24 +268,42 @@ def gradio_extract_lycoris_locon_tab():
step=0.01, step=0.01,
interactive=True, interactive=True,
) )
disable_cp = gr.Checkbox(label='Disable CP decomposition', value=False, interactive=True) disable_cp = gr.Checkbox(
label='Disable CP decomposition', value=False, interactive=True
)
mode.change( mode.change(
update_mode, update_mode,
inputs=[mode], inputs=[mode],
outputs=[ outputs=[
fixed, threshold, ratio, quantile, fixed,
] threshold,
ratio,
quantile,
],
) )
extract_button = gr.Button('Extract LyCORIS LoCon') extract_button = gr.Button('Extract LyCORIS LoCon')
extract_button.click( extract_button.click(
extract_lycoris_locon, extract_lycoris_locon,
inputs=[db_model, base_model, output_name, device, inputs=[
is_v2, mode, linear_dim, conv_dim, db_model,
linear_threshold, conv_threshold, base_model,
linear_ratio, conv_ratio, output_name,
linear_quantile, conv_quantile, device,
use_sparse_bias, sparsity, disable_cp], is_v2,
mode,
linear_dim,
conv_dim,
linear_threshold,
conv_threshold,
linear_ratio,
conv_ratio,
linear_quantile,
conv_quantile,
use_sparse_bias,
sparsity,
disable_cp,
],
show_progress=False, show_progress=False,
) )
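For reference, a minimal sketch of how this tab assembles its command line from the options shown above. Only the flags visible in this hunk (`--is_v2`, `--device`) are handled; how the model paths and the remaining mode options are appended is not part of this diff, and the `PYTHON` value is an assumption:

```python
import os

PYTHON = 'python'  # assumption; the GUI points this at its own interpreter

def build_extract_cmd(device: str, is_v2: bool) -> str:
    # Mirror the run_cmd construction above: start from the script path,
    # then append only the options that are enabled.
    run_cmd = f'{PYTHON} "{os.path.join("tools", "lycoris_locon_extract.py")}"'
    if is_v2:
        run_cmd += ' --is_v2'
    run_cmd += f' --device {device}'
    return run_cmd

print(build_extract_cmd('cuda', is_v2=False))
# e.g. python "tools/lycoris_locon_extract.py" --device cuda (path separator depends on the OS)
```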

View File

@@ -27,7 +27,9 @@ def caption_images(
return return
print(f'GIT captioning files in {train_data_dir}...') print(f'GIT captioning files in {train_data_dir}...')
run_cmd = f'.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"' run_cmd = (
f'.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"'
)
if not model_id == '': if not model_id == '':
run_cmd += f' --model_id="{model_id}"' run_cmd += f' --model_id="{model_id}"'
run_cmd += f' --batch_size="{int(batch_size)}"' run_cmd += f' --batch_size="{int(batch_size)}"'

File diff suppressed because it is too large

View File

@@ -30,15 +30,19 @@ def resize_lora(
if not os.path.isfile(model): if not os.path.isfile(model):
msgbox('The provided model is not a file') msgbox('The provided model is not a file')
return return
if dynamic_method == 'sv_ratio': if dynamic_method == 'sv_ratio':
if float(dynamic_param) < 2: if float(dynamic_param) < 2:
msgbox(f'Dynamic parameter for {dynamic_method} need to be 2 or greater...') msgbox(
f'Dynamic parameter for {dynamic_method} need to be 2 or greater...'
)
return return
if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative': if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
if float(dynamic_param) < 0 or float(dynamic_param) > 1: if float(dynamic_param) < 0 or float(dynamic_param) > 1:
msgbox(f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...') msgbox(
f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...'
)
return return
# Check if save_to end with one of the defines extension. If not add .safetensors. # Check if save_to end with one of the defines extension. If not add .safetensors.
@@ -108,25 +112,18 @@ def gradio_resize_lora_tab():
with gr.Row(): with gr.Row():
dynamic_method = gr.Dropdown( dynamic_method = gr.Dropdown(
choices=['None', choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'],
'sv_ratio',
'sv_fro',
'sv_cumulative'
],
value='sv_fro', value='sv_fro',
label='Dynamic method', label='Dynamic method',
interactive=True interactive=True,
) )
dynamic_param = gr.Textbox( dynamic_param = gr.Textbox(
label='Dynamic parameter', label='Dynamic parameter',
value='0.9', value='0.9',
interactive=True, interactive=True,
placeholder='Value for the dynamic method selected.' placeholder='Value for the dynamic method selected.',
)
verbose = gr.Checkbox(
label='Verbose',
value=False
) )
verbose = gr.Checkbox(label='Verbose', value=False)
with gr.Row(): with gr.Row():
save_to = gr.Textbox( save_to = gr.Textbox(
label='Save to', label='Save to',
@@ -150,7 +147,10 @@ def gradio_resize_lora_tab():
) )
device = gr.Dropdown( device = gr.Dropdown(
label='Device', label='Device',
choices=['cpu', 'cuda',], choices=[
'cpu',
'cuda',
],
value='cuda', value='cuda',
interactive=True, interactive=True,
) )
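The validation added to `resize_lora()` above boils down to the following rule (a sketch only; the GUI shows a message box and returns instead of raising):

```python
def dynamic_param_is_valid(dynamic_method: str, dynamic_param: str) -> bool:
    # sv_ratio expects a value of 2 or greater; sv_fro and sv_cumulative
    # expect a value between 0 and 1, mirroring the checks in resize_lora().
    value = float(dynamic_param)
    if dynamic_method == 'sv_ratio':
        return value >= 2
    if dynamic_method in ('sv_fro', 'sv_cumulative'):
        return 0 <= value <= 1
    return True  # 'None': no dynamic resizing, so nothing to validate

assert dynamic_param_is_valid('sv_fro', '0.9')
assert not dynamic_param_is_valid('sv_ratio', '1.5')
```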

View File

@@ -74,18 +74,18 @@ def run_cmd_sample(
sample_prompts, sample_prompts,
output_dir, output_dir,
): ):
output_dir = os.path.join(output_dir, "sample") output_dir = os.path.join(output_dir, 'sample')
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
run_cmd = '' run_cmd = ''
if sample_every_n_epochs == 0 and sample_every_n_steps == 0: if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
return run_cmd return run_cmd
# Create the prompt file and get its path # Create the prompt file and get its path
sample_prompts_path = os.path.join(output_dir, "prompt.txt") sample_prompts_path = os.path.join(output_dir, 'prompt.txt')
with open(sample_prompts_path, 'w') as f: with open(sample_prompts_path, 'w') as f:
f.write(sample_prompts) f.write(sample_prompts)
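The sample-prompt handling shown above reduces to the sketch below (the real `run_cmd_sample` also appends the sampling flags to the trainer command, which is omitted here):

```python
import os

def write_sample_prompt_file(output_dir: str, sample_prompts: str,
                             sample_every_n_steps: int, sample_every_n_epochs: int) -> str:
    # Mirrors the hunk above: create <output_dir>/sample, skip when sampling
    # is disabled (both intervals are 0), otherwise persist the prompt text.
    sample_dir = os.path.join(output_dir, 'sample')
    os.makedirs(sample_dir, exist_ok=True)
    if sample_every_n_epochs == 0 and sample_every_n_steps == 0:
        return ''
    prompt_path = os.path.join(sample_dir, 'prompt.txt')
    with open(prompt_path, 'w') as f:
        f.write(sample_prompts)
    return prompt_path
```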

View File

@@ -163,7 +163,10 @@ def gradio_svd_merge_lora_tab():
) )
device = gr.Dropdown( device = gr.Dropdown(
label='Device', label='Device',
choices=['cpu', 'cuda',], choices=[
'cpu',
'cuda',
],
value='cuda', value='cuda',
interactive=True, interactive=True,
) )

File diff suppressed because it is too large

View File

@@ -5,7 +5,9 @@ from .common_gui import get_folder_path
import os import os
def caption_images(train_data_dir, caption_extension, batch_size, thresh, replace_underscores): def caption_images(
train_data_dir, caption_extension, batch_size, thresh, replace_underscores
):
# Check for caption_text_input # Check for caption_text_input
# if caption_text_input == "": # if caption_text_input == "":
# msgbox("Caption text is missing...") # msgbox("Caption text is missing...")
@@ -76,7 +78,7 @@ def gradio_wd14_caption_gui_tab():
batch_size = gr.Number( batch_size = gr.Number(
value=1, label='Batch size', interactive=True value=1, label='Batch size', interactive=True
) )
replace_underscores = gr.Checkbox( replace_underscores = gr.Checkbox(
label='Replace underscores in filenames with spaces', label='Replace underscores in filenames with spaces',
value=False, value=False,
@@ -87,6 +89,12 @@ def gradio_wd14_caption_gui_tab():
caption_button.click( caption_button.click(
caption_images, caption_images,
inputs=[train_data_dir, caption_extension, batch_size, thresh, replace_underscores], inputs=[
train_data_dir,
caption_extension,
batch_size,
thresh,
replace_underscores,
],
show_progress=False, show_progress=False,
) )

View File

@@ -4,6 +4,7 @@
# v3.1: Adding captionning of images to utilities # v3.1: Adding captionning of images to utilities
import gradio as gr import gradio as gr
import easygui
import json import json
import math import math
import os import os
@@ -26,6 +27,7 @@ from library.common_gui import (
run_cmd_training, run_cmd_training,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist,
) )
from library.dreambooth_folder_creation_gui import ( from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab, gradio_dreambooth_folder_creation_tab,
@@ -120,7 +122,8 @@ def save_configuration(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@@ -236,15 +239,16 @@ def open_configuration(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path(file_path) file_path = get_file_path(file_path)
@@ -342,10 +346,11 @@ def train_model(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
): ):
print_only_bool = True if print_only.get('label') == 'True' else False print_only_bool = True if print_only.get('label') == 'True' else False
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
return return
@@ -380,6 +385,9 @@ def train_model(
) )
stop_text_encoder_training_pct = 0 stop_text_encoder_training_pct = 0
if check_if_model_exist(output_name, output_dir, save_model_as):
return
# If string is empty set string to 0. # If string is empty set string to 0.
if text_encoder_lr == '': if text_encoder_lr == '':
text_encoder_lr = 0 text_encoder_lr = 0
@@ -417,7 +425,7 @@ def train_model(
or f.endswith('.webp') or f.endswith('.webp')
] ]
) )
print(f'Folder {folder}: {num_images} images found') print(f'Folder {folder}: {num_images} images found')
# Calculate the total number of steps for this folder # Calculate the total number of steps for this folder
@@ -425,7 +433,7 @@ def train_model(
# Print the result # Print the result
print(f'Folder {folder}: {steps} steps') print(f'Folder {folder}: {steps} steps')
total_steps += steps total_steps += steps
# calculate max_train_steps # calculate max_train_steps
@@ -492,9 +500,7 @@ def train_model(
) )
return return
run_cmd += f' --network_module=lycoris.kohya' run_cmd += f' --network_module=lycoris.kohya'
run_cmd += ( run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"'
f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"'
)
if LoRA_type == 'LyCORIS/LoHa': if LoRA_type == 'LyCORIS/LoHa':
try: try:
import lycoris import lycoris
@@ -504,9 +510,7 @@ def train_model(
) )
return return
run_cmd += f' --network_module=lycoris.kohya' run_cmd += f' --network_module=lycoris.kohya'
run_cmd += ( run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"'
f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"'
)
if LoRA_type == 'Kohya LoCon': if LoRA_type == 'Kohya LoCon':
run_cmd += f' --network_module=networks.lora' run_cmd += f' --network_module=networks.lora'
run_cmd += ( run_cmd += (
@@ -595,8 +599,10 @@ def train_model(
output_dir, output_dir,
) )
if print_only_bool: if print_only_bool:
print('\033[93m\nHere is the trainer command as a reference. It will not be executed:\033[0m\n') print(
'\033[93m\nHere is the trainer command as a reference. It will not be executed:\033[0m\n'
)
print('\033[96m' + run_cmd + '\033[0m\n') print('\033[96m' + run_cmd + '\033[0m\n')
else: else:
print(run_cmd) print(run_cmd)
@@ -611,7 +617,9 @@ def train_model(
if not last_dir.is_dir(): if not last_dir.is_dir():
# Copy inference model for v2 if required # Copy inference model for v2 if required
save_inference_file(output_dir, v2, v_parameterization, output_name) save_inference_file(
output_dir, v2, v_parameterization, output_name
)
def lora_tab( def lora_tab(
@@ -811,7 +819,12 @@ def lora_tab(
# Show of hide LoCon conv settings depending on LoRA type selection # Show of hide LoCon conv settings depending on LoRA type selection
def LoRA_type_change(LoRA_type): def LoRA_type_change(LoRA_type):
print('LoRA type changed...') print('LoRA type changed...')
if LoRA_type == 'LoCon' or LoRA_type == 'Kohya LoCon' or LoRA_type == 'LyCORIS/LoHa' or LoRA_type == 'LyCORIS/LoCon': if (
LoRA_type == 'LoCon'
or LoRA_type == 'Kohya LoCon'
or LoRA_type == 'LyCORIS/LoHa'
or LoRA_type == 'LyCORIS/LoCon'
):
return gr.Group.update(visible=True) return gr.Group.update(visible=True)
else: else:
return gr.Group.update(visible=False) return gr.Group.update(visible=False)
@@ -876,7 +889,8 @@ def lora_tab(
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_dropout_rate, caption_dropout_rate,
noise_offset,additional_parameters, noise_offset,
additional_parameters,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -908,7 +922,7 @@ def lora_tab(
gradio_verify_lora_tab() gradio_verify_lora_tab()
button_run = gr.Button('Train model', variant='primary') button_run = gr.Button('Train model', variant='primary')
button_print = gr.Button('Print training command') button_print = gr.Button('Print training command')
# Setup gradio tensorboard buttons # Setup gradio tensorboard buttons
@@ -992,7 +1006,8 @@ def lora_tab(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
] ]
button_open_config.click( button_open_config.click(
@@ -1001,7 +1016,7 @@ def lora_tab(
outputs=[config_file_name] + settings_list + [LoCon_row], outputs=[config_file_name] + settings_list + [LoCon_row],
show_progress=False, show_progress=False,
) )
button_load_config.click( button_load_config.click(
open_configuration, open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list, inputs=[dummy_db_false, config_file_name] + settings_list,
@@ -1028,7 +1043,7 @@ def lora_tab(
inputs=[dummy_db_false] + settings_list, inputs=[dummy_db_false] + settings_list,
show_progress=False, show_progress=False,
) )
button_print.click( button_print.click(
train_model, train_model,
inputs=[dummy_db_true] + settings_list, inputs=[dummy_db_true] + settings_list,
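Note that `train_model()` in this tab now bails out early when `check_if_model_exist(output_name, output_dir, save_model_as)` returns True (the same guard appears in the Textual Inversion tab below). The helper itself lives in `library/common_gui.py` and is not part of this diff; the sketch below only illustrates the kind of overwrite guard it could implement, so the details (including whether it prompts the user) are assumptions:

```python
import os

def check_if_model_exist(output_name: str, output_dir: str, save_model_as: str) -> bool:
    # Sketch only: work out the file (or folder, for diffusers-style output)
    # that training would write, and report whether it already exists so the
    # caller can abort instead of silently overwriting it.
    if save_model_as in ('ckpt', 'safetensors'):
        target = os.path.join(output_dir, f'{output_name}.{save_model_as}')
    else:
        target = os.path.join(output_dir, output_name)
    if os.path.exists(target):
        print(f'{target} already exists, aborting to avoid overwriting it.')
        return True
    return False
```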

View File

@@ -26,6 +26,7 @@ from library.common_gui import (
gradio_source_model, gradio_source_model,
# set_legacy_8bitadam, # set_legacy_8bitadam,
update_my_data, update_my_data,
check_if_model_exist,
) )
from library.tensorboard_gui import ( from library.tensorboard_gui import (
gradio_tensorboard, gradio_tensorboard,
@@ -110,7 +111,8 @@ def save_configuration(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@@ -222,15 +224,16 @@ def open_configuration(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
ask_for_file = True if ask_for_file.get('label') == 'True' else False ask_for_file = True if ask_for_file.get('label') == 'True' else False
original_file_path = file_path original_file_path = file_path
if ask_for_file: if ask_for_file:
file_path = get_file_path(file_path) file_path = get_file_path(file_path)
@@ -316,7 +319,8 @@ def train_model(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@@ -350,6 +354,9 @@ def train_model(
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
if check_if_model_exist(output_name, output_dir, save_model_as):
return
# Get a list of all subfolders in train_data_dir # Get a list of all subfolders in train_data_dir
subfolders = [ subfolders = [
f f
@@ -761,7 +768,8 @@ def ti_tab(
bucket_reso_steps, bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_dropout_rate, caption_dropout_rate,
noise_offset,additional_parameters, noise_offset,
additional_parameters,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@@ -866,7 +874,8 @@ def ti_tab(
sample_every_n_steps, sample_every_n_steps,
sample_every_n_epochs, sample_every_n_epochs,
sample_sampler, sample_sampler,
sample_prompts,additional_parameters, sample_prompts,
additional_parameters,
] ]
button_open_config.click( button_open_config.click(
@@ -875,7 +884,7 @@ def ti_tab(
outputs=[config_file_name] + settings_list, outputs=[config_file_name] + settings_list,
show_progress=False, show_progress=False,
) )
button_load_config.click( button_load_config.click(
open_configuration, open_configuration,
inputs=[dummy_db_false, config_file_name] + settings_list, inputs=[dummy_db_false, config_file_name] + settings_list,

View File

@@ -7,6 +7,7 @@ import argparse
import itertools import itertools
import math import math
import os import os
import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -17,348 +18,392 @@ from diffusers import DDPMScheduler
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util import library.config_util as config_util
from library.config_util import ( from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
def collate_fn(examples): def collate_fn(examples):
return examples[0] return examples[0]
def train(args): def train(args):
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False) train_util.prepare_dataset_args(args, False)
cache_latents = args.cache_latents cache_latents = args.cache_latents
if args.seed is not None: if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args) tokenizer = train_util.load_tokenizer(args)
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
if args.dataset_config is not None: if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}") print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"] ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored): if any(getattr(args, attr) is not None for attr in ignored):
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) print(
else: "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
user_config = { ", ".join(ignored)
"datasets": [{ )
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) )
}] else:
} user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
if args.no_token_padding: if args.no_token_padding:
train_dataset_group.disable_token_padding() train_dataset_group.disable_token_padding()
if args.debug_dataset: if args.debug_dataset:
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
if cache_latents: if cache_latents:
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") print("prepare accelerator")
if args.gradient_accumulation_steps > 1: if args.gradient_accumulation_steps > 1:
print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong") print(
print( f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です") )
print(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む # モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
# verify load/save model formats # verify load/save model formats
if load_stable_diffusion_format: if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None src_diffusers_model_path = None
else: else:
src_stable_diffusion_ckpt = None src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None: if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors use_safetensors = args.use_safetensors
else: else:
save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加
text_encoder.requires_grad_(train_text_encoder)
if not train_text_encoder:
print("Text Encoder is not trained.")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
if train_text_encoder:
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1)
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()
# train==True is required to enable gradient_checkpointing
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
with accelerator.accumulate(unet):
with torch.no_grad(): with torch.no_grad():
# latentに変換 train_dataset_group.cache_latents(vae)
if cache_latents: vae.to("cpu")
latents = batch["latents"].to(accelerator.device) if torch.cuda.is_available():
else: torch.cuda.empty_cache()
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() gc.collect()
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents # 学習を準備する:モデルを適切な状態にする
noise = torch.randn_like(latents, device=latents.device) train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
if args.noise_offset: unet.requires_grad_(True) # 念のため追加
# https://www.crosslabs.org//blog/diffusion-with-offset-noise text_encoder.requires_grad_(train_text_encoder)
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) if not train_text_encoder:
print("Text Encoder is not trained.")
# Get the text embedding for conditioning if args.gradient_checkpointing:
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): unet.enable_gradient_checkpointing()
input_ids = batch["input_ids"].to(accelerator.device) text_encoder.gradient_checkpointing_enable()
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype)
# Sample a random timestep for each image if not cache_latents:
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) vae.requires_grad_(False)
timesteps = timesteps.long() vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# Add noise to the latents according to the noise magnitude at each timestep # 学習に必要なクラスを準備する
# (this is the forward diffusion process) print("prepare optimizer, data loader etc.")
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) if train_text_encoder:
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
trainable_params = unet.parameters()
# Predict the noise residual _, _, optimizer = train_util.get_optimizer(args, trainable_params)
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization: # dataloaderを準備する
# v-parameterization training # DataLoaderのプロセス数0はメインプロセスになる
target = noise_scheduler.get_velocity(latents, noise, timesteps) n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
else: train_dataloader = torch.utils.data.DataLoader(
target = noise train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collate_fn,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") # 学習ステップ数を計算する
loss = loss.mean([1, 2, 3]) if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
loss_weights = batch["loss_weights"] # 各sampleごとのweight if args.stop_text_encoder_training is None:
loss = loss * loss_weights args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
accelerator.backward(loss) # 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if accelerator.sync_gradients and args.max_grad_norm != 0.0: if args.full_fp16:
if train_text_encoder: assert (
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) args.mixed_precision == "fp16"
else: ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
params_to_clip = unet.parameters() print("enable full fp16 training.")
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) unet.to(weight_dtype)
text_encoder.to(weight_dtype)
optimizer.step() # acceleratorがなんかよろしくやってくれるらしい
lr_scheduler.step() if train_text_encoder:
optimizer.zero_grad(set_to_none=True) unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# Checks if the accelerator has performed an optimization step behind the scenes if not train_text_encoder:
if accelerator.sync_gradients: text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
progress_bar.update(1)
global_step += 1
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
current_loss = loss.detach().item() # resumeする
if args.logging_dir is not None: if args.resume is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} print(f"resume training from state: {args.resume}")
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value accelerator.load_state(args.resume)
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
accelerator.log(logs, step=global_step)
if epoch == 0: # epoch数を計算する
loss_list.append(current_loss) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
else: num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
loss_total -= loss_list[step] if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
loss_list[step] = current_loss args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps: # 学習する
break total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
if args.logging_dir is not None: progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
logs = {"loss/epoch": loss_total / len(loss_list)} global_step = 0
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone() noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if args.save_every_n_epochs is not None: if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path accelerator.init_trackers("dreambooth")
train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1)
is_main_process = accelerator.is_main_process # 指定したステップ数までText Encoderを学習するepoch最初の状態
if is_main_process: unet.train()
unet = unwrap_model(unet) # train==True is required to enable gradient_checkpointing
text_encoder = unwrap_model(text_encoder) if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
accelerator.end_training() for step, batch in enumerate(train_dataloader):
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
if args.save_state: with accelerator.accumulate(unet):
train_util.save_state_on_train_end(args, accelerator) with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
del accelerator # この後メモリを使うのでこれは消す # Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
if is_main_process: # Get the text embedding for conditioning
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, input_ids = batch["input_ids"].to(accelerator.device)
save_dtype, epoch, global_step, text_encoder, unet, vae) encoder_hidden_states = train_util.get_hidden_states(
print("model saved.") args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if train_text_encoder:
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
params_to_clip = unet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end(
args,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True) train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser) train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser) config_util.add_config_arguments(parser)
parser.add_argument("--no_token_padding", action="store_true", parser.add_argument(
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作") "--no_token_padding",
parser.add_argument("--stop_text_encoder_training", type=int, default=None, action="store_true",
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない") help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作",
)
parser.add_argument(
"--stop_text_encoder_training",
type=int,
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
args = parser.parse_args() args = parser.parse_args()
train(args) args = train_util.read_config_from_file(args, parser)
train(args)
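`train_db.py` now imports `toml` and routes `args` through `train_util.read_config_from_file(args, parser)` before calling `train()` (the `toml` import also shows up in `train_textual_inversion.py` below). That helper is defined in `library/train_util.py`, outside this diff; the sketch below only illustrates the general idea of overlaying a `.toml` of key=value options onto argparse defaults so that flags given on the command line still take precedence — the function name and behaviour here are assumptions, not the repository's implementation:

```python
import argparse
import toml

def overlay_toml_config(parser: argparse.ArgumentParser, config_file: str) -> argparse.Namespace:
    # Flatten every [section] of the .toml into one key=value mapping
    # (this sketch ignores the section names), install those values as
    # parser defaults, then re-parse so command-line flags keep precedence.
    config = toml.load(config_file)
    flat = {}
    for key, value in config.items():
        if isinstance(value, dict):
            flat.update(value)
        else:
            flat[key] = value
    parser.set_defaults(**flat)
    return parser.parse_args()
```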

File diff suppressed because it is too large

View File

@@ -3,6 +3,7 @@ import argparse
import gc import gc
import math import math
import os import os
import toml
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -13,8 +14,8 @@ from diffusers import DDPMScheduler
import library.train_util as train_util import library.train_util as train_util
import library.config_util as config_util import library.config_util as config_util
from library.config_util import ( from library.config_util import (
ConfigSanitizer, ConfigSanitizer,
BlueprintGenerator, BlueprintGenerator,
) )
imagenet_templates_small = [ imagenet_templates_small = [
@@ -71,456 +72,500 @@ imagenet_style_templates_small = [
def collate_fn(examples): def collate_fn(examples):
return examples[0] return examples[0]
def train(args): def train(args):
if args.output_name is None: if args.output_name is None:
args.output_name = args.token_string args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template use_template = args.use_object_template or args.use_style_template
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True) train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents cache_latents = args.cache_latents
if args.seed is not None: if args.seed is not None:
set_seed(args.seed) set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args) tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む # モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# Convert the init_word to token_id # Convert the init_word to token_id
if args.init_word is not None: if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print( print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}") f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
else: )
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
num_added_tokens = tokenizer.add_tokens(token_strings)
assert num_added_tokens == args.num_vectors_per_token, f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_ids is not None:
for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
assert len(token_ids) == len(
embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
user_config = {
"datasets": [{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
}]
}
else: else:
print("Train with captions.") init_token_ids = None
user_config = {
"datasets": [{
"subsets": [{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}]
}]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) # add new word to tokenizer, count is num_vectors_per_token
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == args.num_vectors_per_token
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 token_ids = tokenizer.convert_tokens_to_ids(token_strings)
if use_template: print(f"tokens are added: {token_ids}")
print("use template for training captions. is object: {args.use_object_template}") assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
replace_to = " ".join(token_strings)
captions = []
for tmpl in templates:
captions.append(tmpl.format(replace_to))
train_dataset_group.add_replacement("", captions)
if args.num_vectors_per_token > 1: # Resize the token embeddings as we are adding new special tokens to the tokenizer
prompt_replacement = (args.token_string, replace_to) text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_ids is not None:
for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else: else:
prompt_replacement = None use_dreambooth_method = args.in_json is None
else: if use_dreambooth_method:
if args.num_vectors_per_token > 1: print("Use DreamBooth method.")
replace_to = " ".join(token_strings) user_config = {
train_dataset_group.add_replacement(args.token_string, replace_to) "datasets": [
prompt_replacement = (args.token_string, replace_to) {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
else: ]
prompt_replacement = None }
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
else:
unet.eval()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1)
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
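# (0.18215 is the standard Stable Diffusion VAE latent scaling factor)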
b_size = latents.shape[0]
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# use float instead of fp16/bf16 (weight_dtype) because the text encoder is kept in float precision
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
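# the offset is a per-sample, per-channel constant added to the noise, which reportedly helps the model learn overall brightness shifts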
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
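# the velocity target is v = alpha_t * noise - sigma_t * latents rather than the raw noise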
target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                loss = loss * loss_weights

                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = text_encoder.get_input_embeddings().parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
                                         vae, tokenizer, text_encoder, unet, prompt_replacement)

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d'] * lr_scheduler.optimizers[0].param_groups[0]['lr']
                accelerator.log(logs, step=global_step)

            loss_total += current_loss
            avr_loss = loss_total / (step+1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch+1)

        accelerator.wait_for_everyone()

        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()

        if args.save_every_n_epochs is not None:
            model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name

            def save_func():
                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
                ckpt_file = os.path.join(args.output_dir, ckpt_name)
                print(f"saving checkpoint: {ckpt_file}")
                save_weights(ckpt_file, updated_embs, save_dtype)

            def remove_old_func(old_epoch_no):
                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
                old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                if os.path.exists(old_ckpt_file):
                    print(f"removing old checkpoint: {old_ckpt_file}")
                    os.remove(old_ckpt_file)

            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
            if saving and args.save_state:
                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
                                 vae, tokenizer, text_encoder, unet, prompt_replacement)

        # end of epoch

    is_main_process = accelerator.is_main_process
    if is_main_process:
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()

    del accelerator  # この後メモリを使うのでこれは消す

    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

        model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
        ckpt_name = model_name + '.' + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model to {ckpt_file}")
        save_weights(ckpt_file, updated_embs, save_dtype)
        print("model saved.")

    else:
        print("Train with captions.")
        user_config = {
            "datasets": [
                {
                    "subsets": [
                        {
                            "image_dir": args.train_data_dir,
                            "metadata_file": args.in_json,
                        }
                    ]
                }
            ]
        }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
    if use_template:
        print(f"use template for training captions. is object: {args.use_object_template}")
        templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
        replace_to = " ".join(token_strings)
        captions = []
        for tmpl in templates:
            captions.append(tmpl.format(replace_to))
        train_dataset_group.add_replacement("", captions)

        if args.num_vectors_per_token > 1:
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None
    else:
        if args.num_vectors_per_token > 1:
            replace_to = " ".join(token_strings)
            train_dataset_group.add_replacement(args.token_string, replace_to)
            prompt_replacement = (args.token_string, replace_to)
        else:
            prompt_replacement = None

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset_group, show_input_ids=True)
        return
    if len(train_dataset_group) == 0:
        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # モデルに xformers とか memory efficient attention を組み込む
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # 学習を準備する
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()
        with torch.no_grad():
            train_dataset_group.cache_latents(vae)
        vae.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    # 学習に必要なクラスを準備する
    print("prepare optimizer, data loader etc.")
    trainable_params = text_encoder.get_input_embeddings().parameters()
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

    # dataloaderを準備する
    # DataLoaderのプロセス数0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # lr schedulerを用意する
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # acceleratorがなんかよろしくやってくれるらしい
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
    # print(len(index_no_updates), torch.sum(index_no_updates))
    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

    # Freeze all parameters except for the token embeddings in text encoder
    text_encoder.requires_grad_(True)
    text_encoder.text_model.encoder.requires_grad_(False)
    text_encoder.text_model.final_layer_norm.requires_grad_(False)
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
    # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
        unet.train()
    else:
        unet.eval()

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)
        text_encoder.to(weight_dtype)

    # resumeする
    if args.resume is not None:
        print(f"resume training from state: {args.resume}")
        accelerator.load_state(args.resume)

    # epoch数を計算する
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # 学習する
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
    print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f" num epochs / epoch数: {num_train_epochs}")
    print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion")

    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        train_dataset_group.set_current_epoch(epoch + 1)

        text_encoder.train()

        loss_total = 0

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        # latentに変換
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215
                b_size = latents.shape[0]

                # Get the text embedding for conditioning
                input_ids = batch["input_ids"].to(accelerator.device)
                # use float instead of fp16/bf16 (weight_dtype) because the text encoder is kept in float precision
                encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                loss = loss * loss_weights

                loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

                accelerator.backward(loss)
                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                    params_to_clip = text_encoder.get_input_embeddings().parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Let's make sure we don't update any embedding weights besides the newly added token
                with torch.no_grad():
                    unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                        index_no_updates
                    ]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
                )

            current_loss = loss.detach().item()
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()

        if args.save_every_n_epochs is not None:
            model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name

            def save_func():
                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
                ckpt_file = os.path.join(args.output_dir, ckpt_name)
                print(f"saving checkpoint: {ckpt_file}")
                save_weights(ckpt_file, updated_embs, save_dtype)

            def remove_old_func(old_epoch_no):
                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
                old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                if os.path.exists(old_ckpt_file):
                    print(f"removing old checkpoint: {old_ckpt_file}")
                    os.remove(old_ckpt_file)

            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
            if saving and args.save_state:
                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

        train_util.sample_images(
            accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
        )

        # end of epoch

    is_main_process = accelerator.is_main_process
    if is_main_process:
        text_encoder = unwrap_model(text_encoder)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()

    del accelerator  # この後メモリを使うのでこれは消す

    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

        model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
        ckpt_name = model_name + "." + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model to {ckpt_file}")
        save_weights(ckpt_file, updated_embs, save_dtype)
        print("model saved.")


def save_weights(file, updated_embs, save_dtype):
    state_dict = {"emb_params": updated_embs}

    if save_dtype is not None:
        for key in list(state_dict.keys()):
            v = state_dict[key]
            v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import save_file

        save_file(state_dict, file)
    else:
        torch.save(state_dict, file)  # can be loaded in Web UI


def load_weights(file):
    if os.path.splitext(file)[1] == ".safetensors":
        from safetensors.torch import load_file

        data = load_file(file)
    else:
        # compatible with Web UI's file format
        data = torch.load(file, map_location="cpu")
        if type(data) != dict:
            raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")

        if "string_to_param" in data:  # textual inversion embeddings
            data = data["string_to_param"]
            if hasattr(data, "_parameters"):  # support old PyTorch?
                data = getattr(data, "_parameters")

    emb = next(iter(data.values()))
    if type(emb) != torch.Tensor:
        raise ValueError(f"weight file does not contain a Tensor / 重みファイルのデータがTensorではありません: {file}")

    if len(emb.size()) == 1:
        emb = emb.unsqueeze(0)

    return emb


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, False)
    train_util.add_training_arguments(parser, True)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="pt",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
    )

    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
    parser.add_argument(
        "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
    )
    parser.add_argument(
        "--token_string",
        type=str,
        default=None,
        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
    )
    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
    parser.add_argument(
        "--use_object_template",
        action="store_true",
        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
    )
    parser.add_argument(
        "--use_style_template",
        action="store_true",
        help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
    )

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)
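# A minimal usage sketch (not part of the original script), assuming a hypothetical output path:
# an embedding written by save_weights() can be re-opened with load_weights() above, which handles both
# .safetensors and .pt/.ckpt files and returns a tensor of shape (num_vectors_per_token, embedding_dim),
# e.g. (1, 768) for the Stable Diffusion 1.x text encoder.
#
#   emb = load_weights("output/my_token.safetensors")  # hypothetical path
#   print(emb.shape, emb.dtype)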