Update to latest sd-script code

This commit is contained in:
bmaltais 2023-03-20 08:47:00 -04:00
parent 09ad7961e3
commit ccae80186a
23 changed files with 5678 additions and 3640 deletions

View File

@ -41,6 +41,9 @@ If you run on Linux and would like to use the GUI, there is now a port of it as
## Installation
### Runpod
Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379
### Ubuntu
In the terminal, run
@ -189,6 +192,19 @@ This will store your a backup file with your current locally installed pip packa
## Change History
* 2023/03/19 (v21.3.0)
    - Add a function to load the training config from a `.toml` file to each training script. Thanks to Linaqruf for this great contribution!
        - Specify the `.toml` file with `--config_file`. The file holds `key=value` entries whose keys match the command line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details, and the example after this list.
        - All sub-sections are combined into a single dictionary (the section names are ignored).
        - Arguments omitted from the `.toml` file fall back to the command line defaults.
        - Command line arguments override the entries in `.toml`.
        - With the `--output_config` option, you can write the current command line options to the `.toml` file specified with `--config_file`. Use this as a template.
    - Add `--lr_scheduler_type` and `--lr_scheduler_args` arguments for a custom LR scheduler to each training script. Thanks to Isotr0py! [#271](https://github.com/kohya-ss/sd-scripts/pull/271)
        - These work the same way as the custom optimizer arguments; see the sketch after this list.
    - Add prompt weighting, with no length limit, to sample image generation. Thanks to mio2333! [#288](https://github.com/kohya-ss/sd-scripts/pull/288)
        - `( )`, `(xxxx:1.2)` and `[ ]` can be used, as in the example after this list.
    - Fix an exception when training a model in diffusers format with `train_network.py`. Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290)
    - Add a warning if you are about to overwrite an existing model: https://github.com/bmaltais/kohya_ss/issues/404
* 2023/03/19 (v21.2.5):
    - Fix basic captioning logic
    - Add the possibility to not train the text encoder in Dreambooth by setting `Step text encoder training` to -1.
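As a concrete illustration of the `.toml` loader, a minimal config might look like the sketch below. The keys and values here are placeholders, not taken from the commit; any option the training script accepts on the command line can appear as a key.

```toml
# config.toml -- keys mirror the training script's command line options
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
output_dir = "./output"
max_train_steps = 1600
learning_rate = 1e-5

# section names are ignored; all entries are merged into a single dictionary
[optimizer]
optimizer_type = "AdamW8bit"
```

Pass the file with `--config_file config.toml`; options also given on the command line override the entries in the file.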
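For the custom LR scheduler options, usage presumably mirrors the custom optimizer options: `--lr_scheduler_type` names a scheduler class (bare names are expected to resolve from `torch.optim.lr_scheduler`) and `--lr_scheduler_args` passes its keyword arguments. A hypothetical invocation:

```
accelerate launch train_network.py --lr_scheduler_type CosineAnnealingLR --lr_scheduler_args "T_max=1000" "eta_min=1e-6"
```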
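For the new prompt weighting in sample image generation, a line in the sample prompt file can now carry weighted terms, for example (an illustrative prompt, not from the commit):

```
masterpiece, (best quality:1.2), portrait of a cat, [blurry]
```

`( )` and `(xxxx:1.2)` increase the attention weight of a term, while `[ ]` decreases it.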

View File

@ -26,6 +26,7 @@ from library.common_gui import (
    gradio_source_model,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
)
from library.tensorboard_gui import (
    gradio_tensorboard,
@ -104,7 +105,8 @@ def save_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -210,7 +212,8 @@ def open_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -298,7 +301,8 @@ def train_model(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    if pretrained_model_name_or_path == '':
        msgbox('Source model information is missing')
@ -321,6 +325,9 @@ def train_model(
        msgbox('Output folder path is missing')
        return

    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # Get a list of all subfolders in train_data_dir
    subfolders = [
        f

View File

@ -5,6 +5,7 @@ import argparse
import gc
import math
import os
import toml
from tqdm import tqdm

import torch
@ -19,6 +20,7 @@ from library.config_util import (
    BlueprintGenerator,
)


def collate_fn(examples):
    return examples[0]
@ -40,15 +42,23 @@ def train(args):
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "in_json"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        user_config = {
            "datasets": [
                {
                    "subsets": [
                        {
                            "image_dir": args.train_data_dir,
                            "metadata_file": args.in_json,
                        }
                    ]
                }
            ]
        }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
@ -58,11 +68,15 @@ def train(args):
        train_util.debug_dataset(train_dataset_group)
        return
    if len(train_dataset_group) == 0:
        print(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # acceleratorを準備する
    print("prepare accelerator")
@ -86,7 +100,7 @@ def train(args):
        save_stable_diffusion_format = load_stable_diffusion_format
        use_safetensors = args.use_safetensors
    else:
        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())

    # Diffusers版のxformers使用フラグを設定する関数
@ -170,7 +184,13 @@ def train(args):
    # DataLoaderのプロセス数0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
@ -178,13 +198,13 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする # 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16: if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.") print("enable full fp16 training.")
unet.to(weight_dtype) unet.to(weight_dtype)
text_encoder.to(weight_dtype) text_encoder.to(weight_dtype)
@ -192,7 +212,8 @@ def train(args):
    # acceleratorがなんかよろしくやってくれるらしい
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
@ -225,8 +246,9 @@ def train(args):
    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("finetuning")
@ -254,7 +276,8 @@ def train(args):
                # Get the text embedding for conditioning
                input_ids = batch["input_ids"].to(accelerator.device)
                encoder_hidden_states = train_util.get_hidden_states(
                    args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
                )

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)
@ -297,13 +320,17 @@ def train(args):
                progress_bar.update(1)
                global_step += 1

                train_util.sample_images(
                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
                )

            current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
            if args.logging_dir is not None:
                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
                    logs["lr/d*lr"] = (
                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                    )
                accelerator.log(logs, step=global_step)

            # TODO moving averageにする
@ -323,8 +350,20 @@ def train(args):
        if args.save_every_n_epochs is not None:
            src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
            train_util.save_sd_model_on_epoch_end(
                args,
                accelerator,
                src_path,
                save_stable_diffusion_format,
                use_safetensors,
                save_dtype,
                epoch,
                num_train_epochs,
                global_step,
                unwrap_model(text_encoder),
                unwrap_model(unet),
                vae,
            )

        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
@ -342,12 +381,13 @@ def train(args):
    if is_main_process:
        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
        train_util.save_sd_model_on_train_end(
            args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
        )
        print("model saved.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
@ -357,9 +397,10 @@ if __name__ == '__main__':
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)

    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    train(args)

View File

@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os


def main(args):
    assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@ -29,6 +29,9 @@ def main(args):
        caption_path = image_path.with_suffix(args.caption_extension)
        if not os.path.exists(caption_path):
            caption_path = os.path.join(image_path, args.caption_extension)
        caption = caption_path.read_text(encoding='utf-8').strip()

        image_key = str(image_path) if args.full_path else image_path.stem
        if image_key not in metadata:
            metadata[image_key] = {}

View File

@ -4,7 +4,7 @@ from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os


def main(args):
    assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
@ -29,6 +29,9 @@ def main(args):
        tags_path = image_path.with_suffix(args.caption_extension)
        if not os.path.exists(tags_path):
            tags_path = os.path.join(image_path, args.caption_extension)
        tags = tags_path.read_text(encoding='utf-8').strip()

        image_key = str(image_path) if args.full_path else image_path.stem
        if image_key not in metadata:
            metadata[image_key] = {}

View File

@ -125,7 +125,7 @@ def main(args):
tag_text = "" tag_text = ""
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
if p >= args.thresh and i < len(tags): if p >= args.thresh and i < len(tags):
tag_text += ", " + (tags[i].replace("_", " ") if args.replace_underscores else tags[i]) tag_text += ", " + tags[i]
if len(tag_text) > 0: if len(tag_text) > 0:
tag_text = tag_text[2:] # 最初の ", " を消す tag_text = tag_text[2:] # 最初の ", " を消す
@ -190,7 +190,6 @@ if __name__ == '__main__':
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--replace_underscores", action="store_true", help="replace underscores in tags with spaces / タグのアンダースコアをスペースに置き換える")
args = parser.parse_args() args = parser.parse_args()

View File

@ -20,6 +20,7 @@ from library.common_gui import (
    run_cmd_training,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
)
from library.tensorboard_gui import (
    gradio_tensorboard,
@ -102,7 +103,8 @@ def save_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -214,7 +216,8 @@ def open_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -308,8 +311,12 @@ def train_model(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # create caption json file
    if generate_caption_database:
        if not os.path.exists(train_dir):
@ -677,7 +684,8 @@ def finetune_tab():
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        noise_offset,
        additional_parameters,
    ) = gradio_advanced_training()
    color_aug.change(
        color_aug_changed,
@ -770,7 +778,8 @@ def finetune_tab():
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
    ]

    button_run.click(train_model, inputs=settings_list)

View File

@ -1,7 +1,7 @@
from tkinter import filedialog, Tk
import os
import gradio as gr
import easygui
import shutil

folder_symbol = '\U0001f4c2'  # 📂
@ -31,6 +31,34 @@ V1_MODELS = [
ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
def check_if_model_exist(output_name, output_dir, save_model_as):
    if save_model_as in ['diffusers', 'diffusers_safetensors']:
        ckpt_folder = os.path.join(output_dir, output_name)
        if os.path.isdir(ckpt_folder):
            msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print('Aborting training due to existing model with same name...')
                return True
    elif save_model_as in ['ckpt', 'safetensors']:
        ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
        if os.path.isfile(ckpt_file):
            msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                print('Aborting training due to existing model with same name...')
                return True
    else:
        print(
            'Can\'t verify if a model with the same name exists when the save model format is set to "same as source model"; continuing to train...'
        )
        return False

    return False

def update_my_data(my_data):
    # Update optimizer based on use_8bit_adam flag
    use_8bit_adam = my_data.get('use_8bit_adam', False)
@ -41,8 +69,13 @@ def update_my_data(my_data):
    # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
    model_list = my_data.get('model_list', [])
    pretrained_model_name_or_path = my_data.get(
        'pretrained_model_name_or_path', ''
    )
    if (
        not model_list
        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
    ):
        my_data['model_list'] = 'custom'

    # Convert epoch and save_every_n_epochs values to int if they are strings
@ -268,7 +301,7 @@ def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption',
) -> None:
    """
    Add prefix and/or postfix to the content of caption files within a folder.
@ -285,7 +318,9 @@ def add_pre_postfix(
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
@ -303,7 +338,10 @@ def add_pre_postfix(
            prefix_separator = ' ' if prefix else ''
            postfix_separator = ' ' if postfix else ''
            f.write(
                f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
            )
# def add_pre_postfix(
#     folder='', prefix='', postfix='', caption_file_ext='.caption'
@ -348,11 +386,12 @@ def has_ext_files(folder_path: str, file_extension: str) -> bool:
            return True
    return False


def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
) -> None:
    """
    Find and replace text in caption files within a folder.
@ -374,10 +413,14 @@ def find_replace(
    if search_text == '':
        return

    caption_files = [
        f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
    ]

    for caption_file in caption_files:
        with open(
            os.path.join(folder_path, caption_file), 'r', errors='ignore'
        ) as f:
            content = f.read()

        content = content.replace(search_text, replace_text)
@ -385,6 +428,7 @@ def find_replace(
        with open(os.path.join(folder_path, caption_file), 'w') as f:
            f.write(content)


# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
#     print('Running caption find/replace')
#     if not has_ext_files(folder, caption_file_ext):
@ -477,17 +521,15 @@ def set_pretrained_model_name_or_path_input(
    if (
        str(pretrained_model_name_or_path) in V1_MODELS
        or str(pretrained_model_name_or_path) in V2_BASE_MODELS
        or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS
    ):
        pretrained_model_name_or_path = ''
        v2 = False
        v_parameterization = False
    return model_list, pretrained_model_name_or_path, v2, v_parameterization


def set_v2_checkbox(model_list, v2, v_parameterization):
    # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
    if str(model_list) in V2_BASE_MODELS:
        v2 = True
        v_parameterization = False
@ -504,6 +546,7 @@ def set_v2_checkbox(
    return v2, v_parameterization

def set_model_list(
    model_list,
    pretrained_model_name_or_path,
@ -538,7 +581,11 @@ def gradio_config():
                interactive=True,
            )
            button_load_config = gr.Button('Load 💾', elem_id='open_folder')
        config_file_name.change(
            remove_doublequote,
            inputs=[config_file_name],
            outputs=[config_file_name],
        )
    return (
        button_open_config,
        button_save_config,
@ -614,8 +661,18 @@ def gradio_source_model():
            v_parameterization = gr.Checkbox(
                label='v_parameterization', value=False
            )
        v2.change(
            set_v2_checkbox,
            inputs=[model_list, v2, v_parameterization],
            outputs=[v2, v_parameterization],
            show_progress=False,
        )
        v_parameterization.change(
            set_v2_checkbox,
            inputs=[model_list, v2, v_parameterization],
            outputs=[v2, v_parameterization],
            show_progress=False,
        )
        model_list.change(
            set_pretrained_model_name_or_path_input,
            inputs=[
@ -671,7 +728,9 @@ def gradio_training(
            step=1,
        )
        epoch = gr.Number(label='Epoch', value=1, precision=0)
        save_every_n_epochs = gr.Number(
            label='Save every N epochs', value=1, precision=0
        )
        caption_extension = gr.Textbox(
            label='Caption Extension',
            placeholder='(Optional) Extension for caption files. default: .caption',
@ -788,7 +847,7 @@ def run_cmd_training(**kwargs):
        if kwargs.get('save_precision')
        else '',
        f' --seed="{kwargs.get("seed", "")}"'
        if kwargs.get('seed') != ''
        else '',
        f' --caption_extension="{kwargs.get("caption_extension", "")}"'
        if kwargs.get('caption_extension')
@ -964,7 +1023,7 @@ def run_cmd_advanced_training(**kwargs):
        f' --noise_offset={float(kwargs.get("noise_offset", 0))}'
        if not kwargs.get('noise_offset', '') == ''
        else '',
        f' {kwargs.get("additional_parameters", "")}',
    ]
    run_cmd = ''.join(options)
    return run_cmd

View File

@ -153,6 +153,14 @@ def gradio_extract_lora_tab():
        extract_button.click(
            extract_lora,
            inputs=[
                model_tuned,
                model_org,
                save_to,
                save_precision,
                dim,
                v2,
                conv_dim,
            ],
            show_progress=False,
        )

View File

@ -16,12 +16,23 @@ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe'
def extract_lycoris_locon(
    db_model,
    base_model,
    output_name,
    device,
    is_v2,
    mode,
    linear_dim,
    conv_dim,
    linear_threshold,
    conv_threshold,
    linear_ratio,
    conv_ratio,
    linear_quantile,
    conv_quantile,
    use_sparse_bias,
    sparsity,
    disable_cp,
):
    # Check for caption_text_input
    if db_model == '':
@ -41,9 +52,7 @@ def extract_lycoris_locon(
        msgbox('The provided base model is not a file')
        return

    run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"'
    if is_v2:
        run_cmd += f' --is_v2'
    run_cmd += f' --device {device}'
@ -89,6 +98,7 @@ def extract_lycoris_locon(
# if mode == 'threshold':
#     return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True)


def update_mode(mode):
    # Create a list of possible mode values
    modes = ['fixed', 'threshold', 'ratio', 'quantile']
@ -104,12 +114,15 @@ def update_mode(mode):
    # Return the visibility updates as a tuple
    return tuple(updates)

def gradio_extract_lycoris_locon_tab():
    with gr.Tab('Extract LyCORIS LoCON'):
        gr.Markdown(
            'This utility can extract a LyCORIS LoCon network from a finetuned model.'
        )
        lora_ext = gr.Textbox(
            value='*.safetensors', visible=False
        )  # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False)
        lora_ext_name = gr.Textbox(value='LoRA model types', visible=False)
        model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False)
        model_ext_name = gr.Textbox(value='Model types', visible=False)
@ -161,7 +174,10 @@ def gradio_extract_lycoris_locon_tab():
        )
        device = gr.Dropdown(
            label='Device',
            choices=[
                'cpu',
                'cuda',
            ],
            value='cuda',
            interactive=True,
        )
@ -241,7 +257,9 @@ def gradio_extract_lycoris_locon_tab():
                interactive=True,
            )
        with gr.Row():
            use_sparse_bias = gr.Checkbox(
                label='Use sparse bias', value=False, interactive=True
            )
            sparsity = gr.Slider(
                minimum=0,
                maximum=1,
@ -250,24 +268,42 @@ def gradio_extract_lycoris_locon_tab():
                step=0.01,
                interactive=True,
            )
            disable_cp = gr.Checkbox(
                label='Disable CP decomposition', value=False, interactive=True
            )

        mode.change(
            update_mode,
            inputs=[mode],
            outputs=[
                fixed,
                threshold,
                ratio,
                quantile,
            ],
        )

        extract_button = gr.Button('Extract LyCORIS LoCon')

        extract_button.click(
            extract_lycoris_locon,
            inputs=[
                db_model,
                base_model,
                output_name,
                device,
                is_v2,
                mode,
                linear_dim,
                conv_dim,
                linear_threshold,
                conv_threshold,
                linear_ratio,
                conv_ratio,
                linear_quantile,
                conv_quantile,
                use_sparse_bias,
                sparsity,
                disable_cp,
            ],
            show_progress=False,
        )

View File

@ -27,7 +27,9 @@ def caption_images(
        return

    print(f'GIT captioning files in {train_data_dir}...')
    run_cmd = (
        f'.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"'
    )
    if not model_id == '':
        run_cmd += f' --model_id="{model_id}"'
    run_cmd += f' --batch_size="{int(batch_size)}"'

File diff suppressed because it is too large

View File

@ -33,12 +33,16 @@ def resize_lora(
    if dynamic_method == 'sv_ratio':
        if float(dynamic_param) < 2:
            msgbox(
                f'Dynamic parameter for {dynamic_method} needs to be 2 or greater...'
            )
            return

    if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative':
        if float(dynamic_param) < 0 or float(dynamic_param) > 1:
            msgbox(
                f'Dynamic parameter for {dynamic_method} needs to be between 0 and 1...'
            )
            return

    # Check if save_to ends with one of the defined extensions. If not, add .safetensors.
@ -108,25 +112,18 @@ def gradio_resize_lora_tab():
        with gr.Row():
            dynamic_method = gr.Dropdown(
                choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'],
                value='sv_fro',
                label='Dynamic method',
                interactive=True,
            )
            dynamic_param = gr.Textbox(
                label='Dynamic parameter',
                value='0.9',
                interactive=True,
                placeholder='Value for the dynamic method selected.',
            )
            verbose = gr.Checkbox(label='Verbose', value=False)
        with gr.Row():
            save_to = gr.Textbox(
                label='Save to',
@ -150,7 +147,10 @@ def gradio_resize_lora_tab():
        )
        device = gr.Dropdown(
            label='Device',
            choices=[
                'cpu',
                'cuda',
            ],
            value='cuda',
            interactive=True,
        )

View File

@ -74,7 +74,7 @@ def run_cmd_sample(
    sample_prompts,
    output_dir,
):
    output_dir = os.path.join(output_dir, 'sample')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
@ -85,7 +85,7 @@ def run_cmd_sample(
        return run_cmd

    # Create the prompt file and get its path
    sample_prompts_path = os.path.join(output_dir, 'prompt.txt')

    with open(sample_prompts_path, 'w') as f:
        f.write(sample_prompts)

View File

@ -163,7 +163,10 @@ def gradio_svd_merge_lora_tab():
        )
        device = gr.Dropdown(
            label='Device',
            choices=[
                'cpu',
                'cuda',
            ],
            value='cuda',
            interactive=True,
        )

File diff suppressed because it is too large

View File

@ -5,7 +5,9 @@ from .common_gui import get_folder_path
import os


def caption_images(
    train_data_dir, caption_extension, batch_size, thresh, replace_underscores
):
    # Check for caption_text_input
    # if caption_text_input == "":
    #     msgbox("Caption text is missing...")
@ -87,6 +89,12 @@ def gradio_wd14_caption_gui_tab():
        caption_button.click(
            caption_images,
            inputs=[
                train_data_dir,
                caption_extension,
                batch_size,
                thresh,
                replace_underscores,
            ],
            show_progress=False,
        )

View File

@ -4,6 +4,7 @@
# v3.1: Adding captioning of images to utilities
import gradio as gr
import easygui
import json
import math
import os
@ -26,6 +27,7 @@ from library.common_gui import (
    run_cmd_training,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
)
from library.dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
@ -120,7 +122,8 @@ def save_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -236,7 +239,8 @@ def open_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -342,7 +346,8 @@ def train_model(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    print_only_bool = True if print_only.get('label') == 'True' else False
@ -380,6 +385,9 @@ def train_model(
    )
    stop_text_encoder_training_pct = 0

    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # If string is empty set string to 0.
    if text_encoder_lr == '':
        text_encoder_lr = 0
@ -492,9 +500,7 @@ def train_model(
            )
            return
        run_cmd += f' --network_module=lycoris.kohya'
        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"'

    if LoRA_type == 'LyCORIS/LoHa':
        try:
            import lycoris
@ -504,9 +510,7 @@ def train_model(
            )
            return
        run_cmd += f' --network_module=lycoris.kohya'
        run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"'

    if LoRA_type == 'Kohya LoCon':
        run_cmd += f' --network_module=networks.lora'
        run_cmd += (
@ -596,7 +600,9 @@ def train_model(
        )

    if print_only_bool:
        print(
            '\033[93m\nHere is the trainer command as a reference. It will not be executed:\033[0m\n'
        )
        print('\033[96m' + run_cmd + '\033[0m\n')
    else:
        print(run_cmd)
@ -611,7 +617,9 @@ def train_model(
        if not last_dir.is_dir():
            # Copy inference model for v2 if required
            save_inference_file(
                output_dir, v2, v_parameterization, output_name
            )
def lora_tab( def lora_tab(
@ -811,7 +819,12 @@ def lora_tab(
        # Show or hide LoCon conv settings depending on LoRA type selection
        def LoRA_type_change(LoRA_type):
            print('LoRA type changed...')
            if (
                LoRA_type == 'LoCon'
                or LoRA_type == 'Kohya LoCon'
                or LoRA_type == 'LyCORIS/LoHa'
                or LoRA_type == 'LyCORIS/LoCon'
            ):
                return gr.Group.update(visible=True)
            else:
                return gr.Group.update(visible=False)
@ -876,7 +889,8 @@ def lora_tab(
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        noise_offset,
        additional_parameters,
    ) = gradio_advanced_training()
    color_aug.change(
        color_aug_changed,
@ -992,7 +1006,8 @@ def lora_tab(
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
    ]

    button_open_config.click(

View File

@ -26,6 +26,7 @@ from library.common_gui import (
    gradio_source_model,
    # set_legacy_8bitadam,
    update_my_data,
    check_if_model_exist,
)
from library.tensorboard_gui import (
    gradio_tensorboard,
@ -110,7 +111,8 @@ def save_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -222,7 +224,8 @@ def open_configuration(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    # Get list of function parameters and values
    parameters = list(locals().items())
@ -316,7 +319,8 @@ def train_model(
    sample_every_n_steps,
    sample_every_n_epochs,
    sample_sampler,
    sample_prompts,
    additional_parameters,
):
    if pretrained_model_name_or_path == '':
        msgbox('Source model information is missing')
@ -350,6 +354,9 @@ def train_model(
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if check_if_model_exist(output_name, output_dir, save_model_as):
        return

    # Get a list of all subfolders in train_data_dir
    subfolders = [
        f
@ -761,7 +768,8 @@ def ti_tab(
        bucket_reso_steps,
        caption_dropout_every_n_epochs,
        caption_dropout_rate,
        noise_offset,
        additional_parameters,
    ) = gradio_advanced_training()
    color_aug.change(
        color_aug_changed,
@ -866,7 +874,8 @@ def ti_tab(
        sample_every_n_steps,
        sample_every_n_epochs,
        sample_sampler,
        sample_prompts,
        additional_parameters,
    ]

    button_open_config.click(

View File

@ -7,6 +7,7 @@ import argparse
import itertools
import math
import os
import toml
from tqdm import tqdm

import torch
@ -43,12 +44,16 @@ def train(args):
        user_config = config_util.load_user_config(args.dataset_config)
        ignored = ["train_data_dir", "reg_data_dir"]
        if any(getattr(args, attr) is not None for attr in ignored):
            print(
                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                    ", ".join(ignored)
                )
            )
    else:
        user_config = {
            "datasets": [
                {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
            ]
        }

    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
@ -62,15 +67,20 @@ def train(args):
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    # acceleratorを準備する
    print("prepare accelerator")

    if args.gradient_accumulation_steps > 1:
        print(
            f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
        )
        print(
            f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
        )

    accelerator, unwrap_model = train_util.prepare_accelerator(args)
@ -92,7 +102,7 @@ def train(args):
        save_stable_diffusion_format = load_stable_diffusion_format
        use_safetensors = args.use_safetensors
    else:
        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())

    # モデルに xformers とか memory efficient attention を組み込む
@ -129,7 +139,7 @@ def train(args):
    # 学習に必要なクラスを準備する
    print("prepare optimizer, data loader etc.")
    if train_text_encoder:
        trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
    else:
        trainable_params = unet.parameters()
@ -139,7 +149,13 @@ def train(args):
    # DataLoaderのプロセス数0はメインプロセスになる
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
@ -150,13 +166,13 @@ def train(args):
        args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

    # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        unet.to(weight_dtype)
        text_encoder.to(weight_dtype)
@ -164,7 +180,8 @@ def train(args):
    # acceleratorがなんかよろしくやってくれるらしい
    if train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
@ -201,8 +218,9 @@ def train(args):
    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
    )

    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth")
@ -247,7 +265,8 @@ def train(args):
                with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = train_util.get_hidden_states(
                        args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
                    )

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@ -277,7 +296,7 @@ def train(args):
accelerator.backward(loss) accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0: if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if train_text_encoder: if train_text_encoder:
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
else: else:
params_to_clip = unet.parameters() params_to_clip = unet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@ -291,13 +310,17 @@ def train(args):
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
)
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
if epoch == 0: if epoch == 0:
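Note: the reflowed lr/d*lr metric logs the step size a D-Adaptation optimizer actually applies, i.e. its distance estimate d times the base lr. A tiny sketch of the same computation, assuming the optimizer exposes d in each param group the way the dadaptation package does:

def effective_dadaptation_lr(optimizer) -> float:
    group = optimizer.param_groups[0]
    return group["d"] * group["lr"]  # the product is the effective learning rate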
@ -321,8 +344,20 @@ def train(args):
             if args.save_every_n_epochs is not None:
                 src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
-                train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors,
-                                                      save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae)
+                train_util.save_sd_model_on_epoch_end(
+                    args,
+                    accelerator,
+                    src_path,
+                    save_stable_diffusion_format,
+                    use_safetensors,
+                    save_dtype,
+                    epoch,
+                    num_train_epochs,
+                    global_step,
+                    unwrap_model(text_encoder),
+                    unwrap_model(unet),
+                    vae,
+                )
         train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

@ -340,12 +375,13 @@ def train(args):
     if is_main_process:
         src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
-        train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
-                                              save_dtype, epoch, global_step, text_encoder, unet, vae)
+        train_util.save_sd_model_on_train_end(
+            args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
+        )
         print("model saved.")

-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     train_util.add_sd_models_arguments(parser)

@ -355,10 +391,19 @@ if __name__ == '__main__':
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
-    parser.add_argument("--no_token_padding", action="store_true",
-                        help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
-    parser.add_argument("--stop_text_encoder_training", type=int, default=None,
-                        help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない")
+    parser.add_argument(
+        "--no_token_padding",
+        action="store_true",
+        help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
+    )
+    parser.add_argument(
+        "--stop_text_encoder_training",
+        type=int,
+        default=None,
+        help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
+    )
     args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
     train(args)
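Note: each trainer now funnels args through train_util.read_config_from_file(args, parser) right after parsing. The sketch below shows the merge order the call site implies, assuming a --config_file option that names a .toml whose keys mirror the CLI flags, with explicitly passed CLI flags winning; it is an illustrative reimplementation, not the actual helper.

import argparse
import toml


def read_config_from_file_sketch(args: argparse.Namespace, parser: argparse.ArgumentParser) -> argparse.Namespace:
    config_file = getattr(args, "config_file", None)
    if not config_file:
        return args  # nothing to merge
    config = toml.load(config_file)
    flat = {}
    for section in config.values():  # assumption: sub-sections are flattened and names ignored
        if isinstance(section, dict):
            flat.update(section)
    defaults = vars(parser.parse_args([]))  # parser defaults only
    merged = argparse.Namespace(**{**defaults, **flat})
    return parser.parse_args(namespace=merged)  # flags given on the CLI override the file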

View File: train_network.py

@ -7,6 +7,7 @@ import os
 import random
 import time
 import json
+import toml
 from tqdm import tqdm
 import torch

@ -41,7 +42,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
         logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be same to textencoder
         if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
-            logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
+            logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
     return logs

@ -69,24 +70,31 @@ def train(args):
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None for attr in ignored):
             print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                    ", ".join(ignored)
+                )
+            )
     else:
         if use_dreambooth_method:
             print("Use DreamBooth method.")
             user_config = {
-                "datasets": [{
-                    "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
-                }]
+                "datasets": [
+                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+                ]
             }
         else:
             print("Train with captions.")
             user_config = {
-                "datasets": [{
-                    "subsets": [{
-                        "image_dir": args.train_data_dir,
-                        "metadata_file": args.in_json,
-                    }]
-                }]
+                "datasets": [
+                    {
+                        "subsets": [
+                            {
+                                "image_dir": args.train_data_dir,
+                                "metadata_file": args.in_json,
+                            }
+                        ]
+                    }
+                ]
             }
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
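Note: the user_config dict assembled above is the in-memory twin of a dataset config file. A hedged illustration of the equivalent TOML, assuming config_util.load_user_config parses this shape; the paths are placeholders.

import toml

config_text = """
[[datasets]]

[[datasets.subsets]]
image_dir = "/path/to/train_images"
metadata_file = "/path/to/meta.json"
"""
user_config = toml.loads(config_text)
assert user_config["datasets"][0]["subsets"][0]["metadata_file"] == "/path/to/meta.json"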
@ -96,11 +104,14 @@ def train(args):
         train_util.debug_dataset(train_dataset_group)
         return
     if len(train_dataset_group) == 0:
-        print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
+        print(
+            "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
+        )
         return
     if cache_latents:
-        assert train_dataset_group.is_latent_cacheable(
+        assert (
+            train_dataset_group.is_latent_cacheable()
         ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
     # acceleratorを準備する

@ -136,6 +147,7 @@ def train(args):
     # prepare network
     import sys
+
     sys.path.append(os.path.dirname(__file__))
     print("import network module:", args.network_module)
     network_module = importlib.import_module(args.network_module)

@ -143,7 +155,7 @@ def train(args):
     net_kwargs = {}
     if args.network_args is not None:
         for net_arg in args.network_args:
-            key, value = net_arg.split('=')
+            key, value = net_arg.split("=")
             net_kwargs[key] = value
     # if a new network is added in future, add if ~ then blocks for each network (;'∀')
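Note: --network_args forwards free-form key=value pairs to whichever network module gets imported, and the loop above only splits them into a kwargs dict. A small usage sketch; the example keys are made up, not flags any particular module is guaranteed to accept.

network_args = ["conv_dim=4", "conv_alpha=1"]  # e.g. values passed on the command line
net_kwargs = {}
for net_arg in network_args:
    key, value = net_arg.split("=")  # values stay strings; the module casts them as needed
    net_kwargs[key] = value
print(net_kwargs)  # {'conv_dim': '4', 'conv_alpha': '1'}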
@ -174,7 +186,13 @@ def train(args):
     # DataLoaderのプロセス数:0はメインプロセスになる
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
+        train_dataset_group,
+        batch_size=1,
+        shuffle=True,
+        collate_fn=collate_fn,
+        num_workers=n_workers,
+        persistent_workers=args.persistent_data_loader_workers,
+    )
     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:

@ -183,29 +201,31 @@ def train(args):
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
     # lr schedulerを用意する
-    lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                                num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
-                                                num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
     # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
     if args.full_fp16:
-        assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+        assert (
+            args.mixed_precision == "fp16"
+        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
         print("enable full fp16 training.")
         network.to(weight_dtype)
     # acceleratorがなんかよろしくやってくれるらしい
     if train_unet and train_text_encoder:
         unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler)
+            unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
+        )
     elif train_unet:
         unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, network, optimizer, train_dataloader, lr_scheduler)
+            unet, network, optimizer, train_dataloader, lr_scheduler
+        )
     elif train_text_encoder:
         text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            text_encoder, network, optimizer, train_dataloader, lr_scheduler)
+            text_encoder, network, optimizer, train_dataloader, lr_scheduler
+        )
     else:
-        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            network, optimizer, train_dataloader, lr_scheduler)
+        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
     unet.requires_grad_(False)
     unet.to(accelerator.device, dtype=weight_dtype)
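Note: the four-way branch above hands only the trainable pieces to accelerator.prepare, while frozen modules are moved to the device by hand. A minimal, self-contained sketch of that pattern with toy modules:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
trainable = torch.nn.Linear(8, 8)
frozen = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(trainable.parameters(), lr=1e-4)

trainable, optimizer = accelerator.prepare(trainable, optimizer)  # wrapped for DDP/AMP
frozen.requires_grad_(False)
frozen.to(accelerator.device)  # placed on the device but never wrapped or updated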
@ -371,10 +391,7 @@ def train(args):
                     i += 1
                 image_dir_or_metadata_file = v
-                dataset_dirs_info[image_dir_or_metadata_file] = {
-                    "n_repeats": subset.num_repeats,
-                    "img_count": subset.img_count
-                }
+                dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
             dataset_metadata["subsets"] = subsets_metadata
             datasets_metadata.append(dataset_metadata)

@ -393,8 +410,9 @@ def train(args):
         metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
     else:
         # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
-        assert len(
-            train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
+        assert (
+            len(train_dataset_group.datasets) == 1
+        ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
         dataset = train_dataset_group.datasets[0]

@ -403,18 +421,16 @@ def train(args):
         if use_dreambooth_method:
             for subset in dataset.subsets:
                 info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
-                info[os.path.basename(subset.image_dir)] = {
-                    "n_repeats": subset.num_repeats,
-                    "img_count": subset.img_count
-                }
+                info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
         else:
             for subset in dataset.subsets:
                 dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
                     "n_repeats": subset.num_repeats,
-                    "img_count": subset.img_count
+                    "img_count": subset.img_count,
                 }
-        metadata.update({
-            "ss_batch_size_per_device": args.train_batch_size,
-            "ss_total_batch_size": total_batch_size,
-            "ss_resolution": args.resolution,
+        metadata.update(
+            {
+                "ss_batch_size_per_device": args.train_batch_size,
+                "ss_total_batch_size": total_batch_size,
+                "ss_resolution": args.resolution,

@ -431,7 +447,8 @@ def train(args):
-            "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
-            "ss_tag_frequency": json.dumps(dataset.tag_frequency),
-            "ss_bucket_info": json.dumps(dataset.bucket_info),
-        })
+                "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
+                "ss_tag_frequency": json.dumps(dataset.tag_frequency),
+                "ss_bucket_info": json.dumps(dataset.bucket_info),
+            }
+        )
     # add extra args
     if args.network_args:
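Note: the ss_* entries collected above travel inside the saved model file. A usage sketch for reading them back out of a .safetensors checkpoint; the file name is a placeholder.

import json
from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

print(metadata.get("ss_resolution"))
print(json.loads(metadata.get("ss_dataset_dirs", "{}")))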
@ -468,8 +485,9 @@ def train(args):
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
-    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
-                                    num_train_timesteps=1000, clip_sample=False)
+    noise_scheduler = DDPMScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
+    )
     if accelerator.is_main_process:
         accelerator.init_trackers("network_train")

@ -547,7 +565,9 @@ def train(args):
                 progress_bar.update(1)
                 global_step += 1
-                train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+                train_util.sample_images(
+                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                )
             current_loss = loss.detach().item()
             if epoch == 0:

@ -577,14 +597,14 @@ def train(args):
             model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
             def save_func():
-                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
+                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
                 ckpt_file = os.path.join(args.output_dir, ckpt_name)
                 metadata["ss_training_finished_at"] = str(time.time())
                 print(f"saving checkpoint: {ckpt_file}")
                 unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
             def remove_old_func(old_epoch_no):
-                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
+                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
                 old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                 if os.path.exists(old_ckpt_file):
                     print(f"removing old checkpoint: {old_ckpt_file}")

@ -616,7 +636,7 @@ def train(args):
         os.makedirs(args.output_dir, exist_ok=True)
         model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
-        ckpt_name = model_name + '.' + args.save_model_as
+        ckpt_name = model_name + "." + args.save_model_as
         ckpt_file = os.path.join(args.output_dir, ckpt_name)
         print(f"save trained model to {ckpt_file}")

@ -624,7 +644,7 @@ def train(args):
         print("model saved.")

-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     train_util.add_sd_models_arguments(parser)

@ -633,27 +653,41 @@ if __name__ == '__main__':
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
-    parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
-    parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
-                        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")
+    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
+    parser.add_argument(
+        "--save_model_as",
+        type=str,
+        default="safetensors",
+        choices=[None, "ckpt", "pt", "safetensors"],
+        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
+    )
     parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
     parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
-    parser.add_argument("--network_weights", type=str, default=None,
-                        help="pretrained weights for network / 学習するネットワークの初期重み")
-    parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
-    parser.add_argument("--network_dim", type=int, default=None,
-                        help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
-    parser.add_argument("--network_alpha", type=float, default=1,
-                        help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
-    parser.add_argument("--network_args", type=str, default=None, nargs='*',
-                        help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
+    parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
+    parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール")
+    parser.add_argument(
+        "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)"
+    )
+    parser.add_argument(
+        "--network_alpha",
+        type=float,
+        default=1,
+        help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
+    )
+    parser.add_argument(
+        "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
+    )
     parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
-    parser.add_argument("--network_train_text_encoder_only", action="store_true",
-                        help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
-    parser.add_argument("--training_comment", type=str, default=None,
-                        help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
+    parser.add_argument(
+        "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
+    )
+    parser.add_argument(
+        "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
+    )
     args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
     train(args)

View File: train_textual_inversion.py

@ -3,6 +3,7 @@ import argparse
 import gc
 import math
 import os
+import toml
 from tqdm import tqdm
 import torch

@ -104,14 +105,17 @@ def train(args):
         init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
         if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
             print(
-                f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
+                f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
+            )
     else:
         init_token_ids = None
     # add new word to tokenizer, count is num_vectors_per_token
     token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
     num_added_tokens = tokenizer.add_tokens(token_strings)
-    assert num_added_tokens == args.num_vectors_per_token, f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
+    assert (
+        num_added_tokens == args.num_vectors_per_token
+    ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
     token_ids = tokenizer.convert_tokens_to_ids(token_strings)
     print(f"tokens are added: {token_ids}")
@ -132,7 +136,8 @@ def train(args):
     if args.weights is not None:
         embeddings = load_weights(args.weights)
         assert len(token_ids) == len(
-            embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
+            embeddings
+        ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
         # print(token_ids, embeddings.size())
         for token_id, embedding in zip(token_ids, embeddings):
             token_embeds[token_id] = embedding

@ -148,25 +153,33 @@ def train(args):
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None for attr in ignored):
-            print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
+            print(
+                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                    ", ".join(ignored)
+                )
+            )
     else:
         use_dreambooth_method = args.in_json is None
         if use_dreambooth_method:
             print("Use DreamBooth method.")
             user_config = {
-                "datasets": [{
-                    "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)
-                }]
+                "datasets": [
+                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+                ]
             }
         else:
             print("Train with captions.")
             user_config = {
-                "datasets": [{
-                    "subsets": [{
-                        "image_dir": args.train_data_dir,
-                        "metadata_file": args.in_json,
-                    }]
-                }]
+                "datasets": [
+                    {
+                        "subsets": [
+                            {
+                                "image_dir": args.train_data_dir,
+                                "metadata_file": args.in_json,
+                            }
+                        ]
+                    }
+                ]
             }
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
@ -202,7 +215,9 @@ def train(args):
         return
     if cache_latents:
-        assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+        assert (
+            train_dataset_group.is_latent_cacheable()
+        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
     # モデルに xformers とか memory efficient attention を組み込む
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

@ -232,7 +247,13 @@ def train(args):
     # DataLoaderのプロセス数:0はメインプロセスになる
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
+        train_dataset_group,
+        batch_size=1,
+        shuffle=True,
+        collate_fn=collate_fn,
+        num_workers=n_workers,
+        persistent_workers=args.persistent_data_loader_workers,
+    )
     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:

@ -240,13 +261,12 @@ def train(args):
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
     # lr schedulerを用意する
-    lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                                num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-                                                num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
     # acceleratorがなんかよろしくやってくれるらしい
     text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler)
+        text_encoder, optimizer, train_dataloader, lr_scheduler
+    )
     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
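Note: index_no_updates marks every embedding row that existed before the new tokens were appended; after each optimizer step those rows are copied back from a snapshot so only the new vectors learn. A minimal sketch of the mask, with illustrative sizes:

import torch

vocab_size, first_new_token_id = 49411, 49408  # illustrative: three tokens appended to CLIP's 49408
index_no_updates = torch.arange(vocab_size) < first_new_token_id
# after each step: embeddings.weight[index_no_updates] = orig_embeds_params[index_no_updates]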
@ -302,8 +322,9 @@ def train(args):
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
-    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
-                                    num_train_timesteps=1000, clip_sample=False)
+    noise_scheduler = DDPMScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
+    )
     if accelerator.is_main_process:
         accelerator.init_trackers("textual_inversion")

@ -373,21 +394,26 @@ def train(args):
             # Let's make sure we don't update any embedding weights besides the newly added token
             with torch.no_grad():
-                unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
+                unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
+                    index_no_updates
+                ]
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
-                train_util.sample_images(accelerator, args, None, global_step, accelerator.device,
-                                         vae, tokenizer, text_encoder, unet, prompt_replacement)
+                train_util.sample_images(
+                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
+                )
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
                 if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
-                    logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
+                    logs["lr/d*lr"] = (
+                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
+                    )
                 accelerator.log(logs, step=global_step)
             loss_total += current_loss

@ -410,13 +436,13 @@ def train(args):
             model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
             def save_func():
-                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
+                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
                 ckpt_file = os.path.join(args.output_dir, ckpt_name)
                 print(f"saving checkpoint: {ckpt_file}")
                 save_weights(ckpt_file, updated_embs, save_dtype)
             def remove_old_func(old_epoch_no):
-                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
+                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
                 old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                 if os.path.exists(old_ckpt_file):
                     print(f"removing old checkpoint: {old_ckpt_file}")

@ -426,8 +452,9 @@ def train(args):
             if saving and args.save_state:
                 train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
-            train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device,
-                                     vae, tokenizer, text_encoder, unet, prompt_replacement)
+            train_util.sample_images(
+                accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
+            )
         # end of epoch

@ -448,7 +475,7 @@ def train(args):
         os.makedirs(args.output_dir, exist_ok=True)
         model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
-        ckpt_name = model_name + '.' + args.save_model_as
+        ckpt_name = model_name + "." + args.save_model_as
         ckpt_file = os.path.join(args.output_dir, ckpt_name)
         print(f"save trained model to {ckpt_file}")

@ -465,27 +492,29 @@ def save_weights(file, updated_embs, save_dtype):
         v = v.detach().clone().to("cpu").to(save_dtype)
         state_dict[key] = v
-    if os.path.splitext(file)[1] == '.safetensors':
+    if os.path.splitext(file)[1] == ".safetensors":
         from safetensors.torch import save_file
         save_file(state_dict, file)
     else:
         torch.save(state_dict, file)  # can be loaded in Web UI

 def load_weights(file):
-    if os.path.splitext(file)[1] == '.safetensors':
+    if os.path.splitext(file)[1] == ".safetensors":
         from safetensors.torch import load_file
         data = load_file(file)
     else:
         # compatible to Web UI's file format
-        data = torch.load(file, map_location='cpu')
+        data = torch.load(file, map_location="cpu")
         if type(data) != dict:
             raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
-        if 'string_to_param' in data:  # textual inversion embeddings
-            data = data['string_to_param']
-            if hasattr(data, '_parameters'):  # support old PyTorch?
-                data = getattr(data, '_parameters')
+        if "string_to_param" in data:  # textual inversion embeddings
+            data = data["string_to_param"]
+            if hasattr(data, "_parameters"):  # support old PyTorch?
+                data = getattr(data, "_parameters")
     emb = next(iter(data.values()))
     if type(emb) != torch.Tensor:
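Note: save_weights and load_weights above pick the serializer from the file extension, keeping .pt/.ckpt output loadable by the Web UI. A usage sketch; the file names and the emb_params key are assumptions for illustration, not the script's guaranteed layout.

import torch

emb = load_weights("learned_embeds.safetensors")  # returns the embedding tensor
save_weights("learned_embeds.pt", {"emb_params": emb}, torch.float16)  # Web UI-compatible format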
@ -497,7 +526,7 @@ def load_weights(file):
     return emb

-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     train_util.add_sd_models_arguments(parser)

@ -506,21 +535,37 @@ if __name__ == '__main__':
     train_util.add_optimizer_arguments(parser)
     config_util.add_config_arguments(parser)
-    parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
-                        help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
+    parser.add_argument(
+        "--save_model_as",
+        type=str,
+        default="pt",
+        choices=[None, "ckpt", "pt", "safetensors"],
+        help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
+    )
-    parser.add_argument("--weights", type=str, default=None,
-                        help="embedding weights to initialize / 学習するネットワークの初期重み")
-    parser.add_argument("--num_vectors_per_token", type=int, default=1,
-                        help='number of vectors per token / トークンに割り当てるembeddingsの要素数')
-    parser.add_argument("--token_string", type=str, default=None,
-                        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
-    parser.add_argument("--init_word", type=str, default=None,
-                        help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
-    parser.add_argument("--use_object_template", action='store_true',
-                        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
-    parser.add_argument("--use_style_template", action='store_true',
-                        help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する")
+    parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
+    parser.add_argument(
+        "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
+    )
+    parser.add_argument(
+        "--token_string",
+        type=str,
+        default=None,
+        help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
+    )
+    parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
+    parser.add_argument(
+        "--use_object_template",
+        action="store_true",
+        help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
+    )
+    parser.add_argument(
+        "--use_style_template",
+        action="store_true",
+        help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
+    )
     args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
     train(args)