From ccae80186a95854a46098a64126d7cbd7665b713 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 20 Mar 2023 08:47:00 -0400 Subject: [PATCH] Update to latest sd-script code --- README.md | 16 + dreambooth_gui.py | 19 +- fine_tune.py | 623 ++-- finetune/merge_captions_to_metadata.py | 5 +- finetune/merge_dd_tags_to_metadata.py | 5 +- finetune/tag_images_by_wd14_tagger.py | 3 +- finetune_gui.py | 25 +- library/basic_caption_gui.py | 2 +- library/common_gui.py | 137 +- library/extract_lora_gui.py | 10 +- library/extract_lycoris_locon_gui.py | 82 +- library/git_caption_gui.py | 4 +- library/lpw_stable_diffusion.py | 1179 +++++++ library/resize_lora_gui.py | 32 +- library/sampler_gui.py | 10 +- library/svd_merge_lora_gui.py | 5 +- library/train_util.py | 4403 +++++++++++++----------- library/wd14_caption_gui.py | 14 +- lora_gui.py | 61 +- textual_inversion_gui.py | 25 +- train_db.py | 625 ++-- train_network.py | 1186 +++---- train_textual_inversion.py | 847 ++--- 23 files changed, 5678 insertions(+), 3640 deletions(-) create mode 100644 library/lpw_stable_diffusion.py diff --git a/README.md b/README.md index cce5fe8..636c5a6 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,9 @@ If you run on Linux and would like to use the GUI, there is now a port of it as ## Installation +### Runpod +Follow the instructions found in this discussion: https://github.com/bmaltais/kohya_ss/discussions/379 + ### Ubuntu In the terminal, run @@ -189,6 +192,19 @@ This will store your a backup file with your current locally installed pip packa ## Change History +* 2023/03/19 (v21.3.0) + - Add a function to load training config with `.toml` to each training script. Thanks to Linaqruf for this great contribution! + - Specify `.toml` file with `--config_file`. `.toml` file has `key=value` entries. Keys are same as command line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details. + - All sub-sections are combined to a single dictionary (the section names are ignored.) + - Omitted arguments are the default values for command line arguments. + - Command line args override the arguments in `.toml`. + - With `--output_config` option, you can output current command line options to the `.toml` specified with`--config_file`. Please use as a template. + - Add `--lr_scheduler_type` and `--lr_scheduler_args` arguments for custom LR scheduler to each training script. Thanks to Isotr0py! [#271](https://github.com/kohya-ss/sd-scripts/pull/271) + - Same as the optimizer. + - Add sample image generation with weight and no length limit. Thanks to mio2333! [#288](https://github.com/kohya-ss/sd-scripts/pull/288) + - `( )`, `(xxxx:1.2)` and `[ ]` can be used. + - Fix exception on training model in diffusers format with `train_network.py` Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290) + - Add warning if you are about to overwrite an existing model: https://github.com/bmaltais/kohya_ss/issues/404 * 2023/03/19 (v21.2.5): - Fix basic captioning logic - Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1. diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 7d24b70..6053d44 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -26,6 +26,7 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, + check_if_model_exist, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -104,7 +105,8 @@ def save_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -210,15 +212,16 @@ def open_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) - + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - + if ask_for_file: file_path = get_file_path(file_path) @@ -298,7 +301,8 @@ def train_model( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -321,6 +325,9 @@ def train_model( msgbox('Output folder path is missing') return + if check_if_model_exist(output_name, output_dir, save_model_as): + return + # Get a list of all subfolders in train_data_dir subfolders = [ f @@ -787,7 +794,7 @@ def dreambooth_tab( outputs=[config_file_name] + settings_list, show_progress=False, ) - + button_load_config.click( open_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, diff --git a/fine_tune.py b/fine_tune.py index 1255759..d927bd7 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,6 +5,7 @@ import argparse import gc import math import os +import toml from tqdm import tqdm import torch @@ -15,351 +16,391 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, + ConfigSanitizer, + BlueprintGenerator, ) + def collate_fn(examples): - return examples[0] + return examples[0] def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - cache_latents = args.cache_latents + cache_latents = args.cache_latents - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - user_config = { - "datasets": [{ - "subsets": [{ - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - }] - }] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") - return - - if cache_latents: - assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) - - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - - # Diffusers版のxformers使用フラグを設定する関数 - def set_diffusers_xformers_flag(model, valid): - # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう - # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) - # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか - # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - - # Recursively walk through all the children. - # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - fn_recursive_set_mem_eff(model) - - # モデルに xformers とか memory efficient attention を組み込む - if args.diffusers_xformers: - print("Use xformers by Diffusers") - set_diffusers_xformers_flag(unet, True) - else: - # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある - print("Disable Diffusers' xformers") - set_diffusers_xformers_flag(unet, False) - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # 学習を準備する:モデルを適切な状態にする - training_models = [] - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - training_models.append(unet) - - if args.train_text_encoder: - print("enable text encoder training") - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - training_models.append(text_encoder) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) # text encoderは学習しない - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - text_encoder.train() # required for gradient_checkpointing + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) else: - text_encoder.eval() + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - for m in training_models: - m.requires_grad_(True) - params = [] - for m in training_models: - params.extend(m.parameters()) - params_to_optimize = params + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path - # acceleratorがなんかよろしくやってくれるらしい - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう + # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) + # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか + # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + for child in module.children(): + fn_recursive_set_mem_eff(child) - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + fn_recursive_set_mem_eff(model) - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + print("Use xformers by Diffusers") + set_diffusers_xformers_flag(unet, True) + else: + # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある + print("Disable Diffusers' xformers") + set_diffusers_xformers_flag(unet, False) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - if accelerator.is_main_process: - accelerator.init_trackers("finetuning") + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) - for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + if args.train_text_encoder: + print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) + else: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) # text encoderは学習しない + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) for m in training_models: - m.train() + m.requires_grad_(True) + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params - loss_total = 0 - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - with torch.set_grad_enabled(args.train_text_encoder): - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * len(train_dataloader) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) - current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value - logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] - accelerator.log(logs, step=global_step) + if accelerator.is_main_process: + accelerator.init_trackers("finetuning") - # TODO moving averageにする - loss_total += current_loss - avr_loss = loss_total / (step+1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) - if global_step >= args.max_train_steps: - break + for m in training_models: + m.train() - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch+1) + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] - accelerator.wait_for_everyone() + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) - if args.save_every_n_epochs is not None: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - is_main_process = accelerator.is_main_process - if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() - accelerator.end_training() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - del accelerator # この後メモリを使うのでこれは消す + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, global_step, text_encoder, unet, vae) - print("model saved.") + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + # TODO moving averageにする + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + print("model saved.") -if __name__ == '__main__': - parser = argparse.ArgumentParser() +if __name__ == "__main__": + parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True, True) - train_util.add_training_arguments(parser, False) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) - parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers / Diffusersでxformersを使用する') - parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") - args = parser.parse_args() - train(args) + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/finetune/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py index cbc5033..491e459 100644 --- a/finetune/merge_captions_to_metadata.py +++ b/finetune/merge_captions_to_metadata.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List from tqdm import tqdm import library.train_util as train_util - +import os def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" @@ -29,6 +29,9 @@ def main(args): caption_path = image_path.with_suffix(args.caption_extension) caption = caption_path.read_text(encoding='utf-8').strip() + if not os.path.exists(caption_path): + caption_path = os.path.join(image_path, args.caption_extension) + image_key = str(image_path) if args.full_path else image_path.stem if image_key not in metadata: metadata[image_key] = {} diff --git a/finetune/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py index 4285feb..8823a9c 100644 --- a/finetune/merge_dd_tags_to_metadata.py +++ b/finetune/merge_dd_tags_to_metadata.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List from tqdm import tqdm import library.train_util as train_util - +import os def main(args): assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください" @@ -29,6 +29,9 @@ def main(args): tags_path = image_path.with_suffix(args.caption_extension) tags = tags_path.read_text(encoding='utf-8').strip() + if not os.path.exists(tags_path): + tags_path = os.path.join(image_path, args.caption_extension) + image_key = str(image_path) if args.full_path else image_path.stem if image_key not in metadata: metadata[image_key] = {} diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 933117c..609b8c5 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -125,7 +125,7 @@ def main(args): tag_text = "" for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで if p >= args.thresh and i < len(tags): - tag_text += ", " + (tags[i].replace("_", " ") if args.replace_underscores else tags[i]) + tag_text += ", " + tags[i] if len(tag_text) > 0: tag_text = tag_text[2:] # 最初の ", " を消す @@ -190,7 +190,6 @@ if __name__ == '__main__': help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") parser.add_argument("--debug", action="store_true", help="debug mode") - parser.add_argument("--replace_underscores", action="store_true", help="replace underscores in tags with spaces / タグのアンダースコアをスペースに置き換える") args = parser.parse_args() diff --git a/finetune_gui.py b/finetune_gui.py index 59dffd8..b56ab05 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -20,6 +20,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, + check_if_model_exist, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -102,7 +103,8 @@ def save_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -214,15 +216,16 @@ def open_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) - + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - + if ask_for_file: file_path = get_file_path(file_path) @@ -308,8 +311,12 @@ def train_model( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): + if check_if_model_exist(output_name, output_dir, save_model_as): + return + # create caption json file if generate_caption_database: if not os.path.exists(train_dir): @@ -677,7 +684,8 @@ def finetune_tab(): bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, - noise_offset,additional_parameters, + noise_offset, + additional_parameters, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -770,7 +778,8 @@ def finetune_tab(): sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ] button_run.click(train_model, inputs=settings_list) @@ -781,7 +790,7 @@ def finetune_tab(): outputs=[config_file_name] + settings_list, show_progress=False, ) - + button_load_config.click( open_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index 1672880..669eaa0 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -122,7 +122,7 @@ def gradio_basic_caption_gui_tab(): label='Replacement text', placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing', interactive=True, - ) + ) caption_button = gr.Button('Caption images') caption_button.click( caption_images, diff --git a/library/common_gui.py b/library/common_gui.py index 05accda..604576e 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -1,7 +1,7 @@ from tkinter import filedialog, Tk import os import gradio as gr -from easygui import msgbox +import easygui import shutil folder_symbol = '\U0001f4c2' # 📂 @@ -31,6 +31,34 @@ V1_MODELS = [ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS +def check_if_model_exist(output_name, output_dir, save_model_as): + if save_model_as in ['diffusers', 'diffusers_safetendors']: + ckpt_folder = os.path.join(output_dir, output_name) + if os.path.isdir(ckpt_folder): + msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?' + if not easygui.ynbox(msg, 'Overwrite Existing Model?'): + print( + 'Aborting training due to existing model with same name...' + ) + return True + elif save_model_as in ['ckpt', 'safetensors']: + ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as) + if os.path.isfile(ckpt_file): + msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?' + if not easygui.ynbox(msg, 'Overwrite Existing Model?'): + print( + 'Aborting training due to existing model with same name...' + ) + return True + else: + print( + 'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...' + ) + return False + + return False + + def update_my_data(my_data): # Update optimizer based on use_8bit_adam flag use_8bit_adam = my_data.get('use_8bit_adam', False) @@ -38,11 +66,16 @@ def update_my_data(my_data): my_data['optimizer'] = 'AdamW8bit' elif 'optimizer' not in my_data: my_data['optimizer'] = 'AdamW' - + # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model model_list = my_data.get('model_list', []) - pretrained_model_name_or_path = my_data.get('pretrained_model_name_or_path', '') - if not model_list or pretrained_model_name_or_path not in ALL_PRESET_MODELS: + pretrained_model_name_or_path = my_data.get( + 'pretrained_model_name_or_path', '' + ) + if ( + not model_list + or pretrained_model_name_or_path not in ALL_PRESET_MODELS + ): my_data['model_list'] = 'custom' # Convert epoch and save_every_n_epochs values to int if they are strings @@ -78,7 +111,7 @@ def update_my_data(my_data): # # If Pretrained model name or path is not one of the preset models then set the preset_model to custom # if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS: # my_data['model_list'] = 'custom' - + # # Fix old config files that contain epoch as str instead of int # for key in ['epoch', 'save_every_n_epochs']: # value = my_data.get(key, -1) @@ -87,10 +120,10 @@ def update_my_data(my_data): # my_data[key] = int(value) # else: # my_data[key] = -1 - + # if my_data.get('LoRA_type', 'Standard') == 'LoCon': # my_data['LoRA_type'] = 'LyCORIS/LoCon' - + # return my_data @@ -265,11 +298,11 @@ def get_saveasfilename_path( def add_pre_postfix( - folder: str = '', - prefix: str = '', - postfix: str = '', - caption_file_ext: str = '.caption' - ) -> None: + folder: str = '', + prefix: str = '', + postfix: str = '', + caption_file_ext: str = '.caption', +) -> None: """ Add prefix and/or postfix to the content of caption files within a folder. If no caption files are found, create one with the requested prefix and/or postfix. @@ -285,7 +318,9 @@ def add_pre_postfix( return image_extensions = ('.jpg', '.jpeg', '.png', '.webp') - image_files = [f for f in os.listdir(folder) if f.lower().endswith(image_extensions)] + image_files = [ + f for f in os.listdir(folder) if f.lower().endswith(image_extensions) + ] for image_file in image_files: caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext @@ -303,7 +338,10 @@ def add_pre_postfix( prefix_separator = ' ' if prefix else '' postfix_separator = ' ' if postfix else '' - f.write(f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}') + f.write( + f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}' + ) + # def add_pre_postfix( # folder='', prefix='', postfix='', caption_file_ext='.caption' @@ -335,11 +373,11 @@ def add_pre_postfix( def has_ext_files(folder_path: str, file_extension: str) -> bool: """ Check if there are any files with the specified extension in the given folder. - + Args: folder_path (str): Path to the folder containing files. file_extension (str): Extension of the files to look for. - + Returns: bool: True if files with the specified extension are found, False otherwise. """ @@ -348,15 +386,16 @@ def has_ext_files(folder_path: str, file_extension: str) -> bool: return True return False + def find_replace( - folder_path: str = '', - caption_file_ext: str = '.caption', - search_text: str = '', - replace_text: str = '' - ) -> None: + folder_path: str = '', + caption_file_ext: str = '.caption', + search_text: str = '', + replace_text: str = '', +) -> None: """ Find and replace text in caption files within a folder. - + Args: folder_path (str, optional): Path to the folder containing caption files. caption_file_ext (str, optional): Extension of the caption files. @@ -364,7 +403,7 @@ def find_replace( replace_text (str, optional): Text to replace the search text with. """ print('Running caption find/replace') - + if not has_ext_files(folder_path, caption_file_ext): msgbox( f'No files with extension {caption_file_ext} were found in {folder_path}...' @@ -374,10 +413,14 @@ def find_replace( if search_text == '': return - caption_files = [f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)] - + caption_files = [ + f for f in os.listdir(folder_path) if f.endswith(caption_file_ext) + ] + for caption_file in caption_files: - with open(os.path.join(folder_path, caption_file), 'r', errors='ignore') as f: + with open( + os.path.join(folder_path, caption_file), 'r', errors='ignore' + ) as f: content = f.read() content = content.replace(search_text, replace_text) @@ -385,6 +428,7 @@ def find_replace( with open(os.path.join(folder_path, caption_file), 'w') as f: f.write(content) + # def find_replace(folder='', caption_file_ext='.caption', find='', replace=''): # print('Running caption find/replace') # if not has_ext_files(folder, caption_file_ext): @@ -477,17 +521,15 @@ def set_pretrained_model_name_or_path_input( if ( str(pretrained_model_name_or_path) in V1_MODELS or str(pretrained_model_name_or_path) in V2_BASE_MODELS - or str(pretrained_model_name_or_path) - in V_PARAMETERIZATION_MODELS + or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS ): pretrained_model_name_or_path = '' v2 = False v_parameterization = False return model_list, pretrained_model_name_or_path, v2, v_parameterization -def set_v2_checkbox( - model_list, v2, v_parameterization -): + +def set_v2_checkbox(model_list, v2, v_parameterization): # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list if str(model_list) in V2_BASE_MODELS: v2 = True @@ -504,6 +546,7 @@ def set_v2_checkbox( return v2, v_parameterization + def set_model_list( model_list, pretrained_model_name_or_path, @@ -515,7 +558,7 @@ def set_model_list( model_list = 'custom' else: model_list = pretrained_model_name_or_path - + return model_list, v2, v_parameterization @@ -538,7 +581,11 @@ def gradio_config(): interactive=True, ) button_load_config = gr.Button('Load 💾', elem_id='open_folder') - config_file_name.change(remove_doublequote, inputs=[config_file_name], outputs=[config_file_name]) + config_file_name.change( + remove_doublequote, + inputs=[config_file_name], + outputs=[config_file_name], + ) return ( button_open_config, button_save_config, @@ -614,8 +661,18 @@ def gradio_source_model(): v_parameterization = gr.Checkbox( label='v_parameterization', value=False ) - v2.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False) - v_parameterization.change(set_v2_checkbox, inputs=[model_list, v2, v_parameterization], outputs=[v2, v_parameterization],show_progress=False) + v2.change( + set_v2_checkbox, + inputs=[model_list, v2, v_parameterization], + outputs=[v2, v_parameterization], + show_progress=False, + ) + v_parameterization.change( + set_v2_checkbox, + inputs=[model_list, v2, v_parameterization], + outputs=[v2, v_parameterization], + show_progress=False, + ) model_list.change( set_pretrained_model_name_or_path_input, inputs=[ @@ -671,7 +728,9 @@ def gradio_training( step=1, ) epoch = gr.Number(label='Epoch', value=1, precision=0) - save_every_n_epochs = gr.Number(label='Save every N epochs', value=1, precision=0) + save_every_n_epochs = gr.Number( + label='Save every N epochs', value=1, precision=0 + ) caption_extension = gr.Textbox( label='Caption Extension', placeholder='(Optional) Extension for caption files. default: .caption', @@ -788,7 +847,7 @@ def run_cmd_training(**kwargs): if kwargs.get('save_precision') else '', f' --seed="{kwargs.get("seed", "")}"' - if kwargs.get('seed') != "" + if kwargs.get('seed') != '' else '', f' --caption_extension="{kwargs.get("caption_extension", "")}"' if kwargs.get('caption_extension') @@ -807,7 +866,7 @@ def run_cmd_training(**kwargs): def gradio_advanced_training(): with gr.Row(): additional_parameters = gr.Textbox( - label='Additional parameters', + label='Additional parameters', placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"', ) with gr.Row(): @@ -964,7 +1023,7 @@ def run_cmd_advanced_training(**kwargs): f' --noise_offset={float(kwargs.get("noise_offset", 0))}' if not kwargs.get('noise_offset', '') == '' else '', - f' {kwargs.get("additional_parameters", "")}' + f' {kwargs.get("additional_parameters", "")}', ] run_cmd = ''.join(options) return run_cmd diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index 7c991a2..6088805 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -153,6 +153,14 @@ def gradio_extract_lora_tab(): extract_button.click( extract_lora, - inputs=[model_tuned, model_org, save_to, save_precision, dim, v2, conv_dim], + inputs=[ + model_tuned, + model_org, + save_to, + save_precision, + dim, + v2, + conv_dim, + ], show_progress=False, ) diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index 43de7a2..13575bb 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -16,12 +16,23 @@ PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def extract_lycoris_locon( - db_model, base_model, output_name, device, - is_v2, mode, linear_dim, conv_dim, - linear_threshold, conv_threshold, - linear_ratio, conv_ratio, - linear_quantile, conv_quantile, - use_sparse_bias, sparsity, disable_cp + db_model, + base_model, + output_name, + device, + is_v2, + mode, + linear_dim, + conv_dim, + linear_threshold, + conv_threshold, + linear_ratio, + conv_ratio, + linear_quantile, + conv_quantile, + use_sparse_bias, + sparsity, + disable_cp, ): # Check for caption_text_input if db_model == '': @@ -41,9 +52,7 @@ def extract_lycoris_locon( msgbox('The provided base model is not a file') return - run_cmd = ( - f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"' - ) + run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"' if is_v2: run_cmd += f' --is_v2' run_cmd += f' --device {device}' @@ -89,10 +98,11 @@ def extract_lycoris_locon( # if mode == 'threshold': # return gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=False), gr.Row.update(visible=True) + def update_mode(mode): # Create a list of possible mode values modes = ['fixed', 'threshold', 'ratio', 'quantile'] - + # Initialize an empty list to store visibility updates updates = [] @@ -104,12 +114,15 @@ def update_mode(mode): # Return the visibility updates as a tuple return tuple(updates) + def gradio_extract_lycoris_locon_tab(): with gr.Tab('Extract LyCORIS LoCON'): gr.Markdown( 'This utility can extract a LyCORIS LoCon network from a finetuned model.' ) - lora_ext = gr.Textbox(value='*.safetensors', visible=False) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) + lora_ext = gr.Textbox( + value='*.safetensors', visible=False + ) # lora_ext = gr.Textbox(value='*.safetensors *.pt', visible=False) lora_ext_name = gr.Textbox(value='LoRA model types', visible=False) model_ext = gr.Textbox(value='*.safetensors *.ckpt', visible=False) model_ext_name = gr.Textbox(value='Model types', visible=False) @@ -161,14 +174,17 @@ def gradio_extract_lycoris_locon_tab(): ) device = gr.Dropdown( label='Device', - choices=['cpu', 'cuda',], + choices=[ + 'cpu', + 'cuda', + ], value='cuda', interactive=True, ) is_v2 = gr.Checkbox(label='is v2', value=False, interactive=True) mode = gr.Dropdown( label='Mode', - choices=['fixed', 'threshold','ratio','quantile'], + choices=['fixed', 'threshold', 'ratio', 'quantile'], value='fixed', interactive=True, ) @@ -241,7 +257,9 @@ def gradio_extract_lycoris_locon_tab(): interactive=True, ) with gr.Row(): - use_sparse_bias = gr.Checkbox(label='Use sparse biais', value=False, interactive=True) + use_sparse_bias = gr.Checkbox( + label='Use sparse biais', value=False, interactive=True + ) sparsity = gr.Slider( minimum=0, maximum=1, @@ -250,24 +268,42 @@ def gradio_extract_lycoris_locon_tab(): step=0.01, interactive=True, ) - disable_cp = gr.Checkbox(label='Disable CP decomposition', value=False, interactive=True) + disable_cp = gr.Checkbox( + label='Disable CP decomposition', value=False, interactive=True + ) mode.change( update_mode, inputs=[mode], outputs=[ - fixed, threshold, ratio, quantile, - ] + fixed, + threshold, + ratio, + quantile, + ], ) extract_button = gr.Button('Extract LyCORIS LoCon') extract_button.click( extract_lycoris_locon, - inputs=[db_model, base_model, output_name, device, - is_v2, mode, linear_dim, conv_dim, - linear_threshold, conv_threshold, - linear_ratio, conv_ratio, - linear_quantile, conv_quantile, - use_sparse_bias, sparsity, disable_cp], + inputs=[ + db_model, + base_model, + output_name, + device, + is_v2, + mode, + linear_dim, + conv_dim, + linear_threshold, + conv_threshold, + linear_ratio, + conv_ratio, + linear_quantile, + conv_quantile, + use_sparse_bias, + sparsity, + disable_cp, + ], show_progress=False, ) diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py index 0ce1dda..9aaf3d9 100644 --- a/library/git_caption_gui.py +++ b/library/git_caption_gui.py @@ -27,7 +27,9 @@ def caption_images( return print(f'GIT captioning files in {train_data_dir}...') - run_cmd = f'.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"' + run_cmd = ( + f'.\\venv\\Scripts\\python.exe "finetune/make_captions_by_git.py"' + ) if not model_id == '': run_cmd += f' --model_id="{model_id}"' run_cmd += f' --batch_size="{int(batch_size)}"' diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py new file mode 100644 index 0000000..3e04b88 --- /dev/null +++ b/library/lpw_stable_diffusion.py @@ -0,0 +1,1179 @@ +# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# and modify to support SD2.x + +import inspect +import re +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import SchedulerMixin, StableDiffusionPipeline +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.utils import logging + + +try: + from diffusers.utils import PIL_INTERPOLATION +except ImportError: + if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } + else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe: StableDiffusionPipeline, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + if clip_skip is None or clip_skip == 1: + text_embedding = pipe.text_encoder(text_input_chunk)[0] + else: + enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0] + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + text_embeddings = pipe.text_encoder(text_input)[0] + return text_embeddings + + +def get_weighted_text_embeddings( + pipe: StableDiffusionPipeline, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + pipe (`StableDiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + clip_skip: int, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, + ) + self.clip_skip = clip_skip + self.__init__additional__() + + # else: + # def __init__( + # self, + # vae: AutoencoderKL, + # text_encoder: CLIPTextModel, + # tokenizer: CLIPTokenizer, + # unet: UNet2DConditionModel, + # scheduler: SchedulerMixin, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPFeatureExtractor, + # ): + # super().__init__( + # vae=vae, + # text_encoder=text_encoder, + # tokenizer=tokenizer, + # unet=unet, + # scheduler=scheduler, + # safety_checker=safety_checker, + # feature_extractor=feature_extractor, + # ) + # self.__init__additional__() + + def __init__additional__(self): + if not hasattr(self, "vae_scale_factor"): + setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + ) + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + bs_embed, seq_len, _ = uncond_embeddings.shape + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def check_inputs(self, prompt, height, width, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + print(height, width) + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + def get_timesteps(self, num_inference_steps, strength, device, is_text2img): + if is_text2img: + return self.scheduler.timesteps.to(device), num_inference_steps + else: + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(device) + return timesteps, num_inference_steps - t_start + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): + if image is None: + shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, None, None + else: + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + init_latents = torch.cat([init_latents] * batch_size, dim=0) + init_latents_orig = init_latents + shape = init_latents.shape + + # add noise to latents using the timesteps + if device.type == "mps": + noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + ) + dtype = text_embeddings.dtype + + # 4. Preprocess image and mask + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + if image is not None: + image = image.to(device=self.device, dtype=dtype) + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=dtype) + mask = torch.cat([mask] * batch_size * num_images_per_prompt) + else: + mask = None + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return image, has_nsfw_concept + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def img2img( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for image-to-image generation. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def inpaint( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for inpaint. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index 527ff67..05a7734 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -30,15 +30,19 @@ def resize_lora( if not os.path.isfile(model): msgbox('The provided model is not a file') return - + if dynamic_method == 'sv_ratio': if float(dynamic_param) < 2: - msgbox(f'Dynamic parameter for {dynamic_method} need to be 2 or greater...') + msgbox( + f'Dynamic parameter for {dynamic_method} need to be 2 or greater...' + ) return - + if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative': if float(dynamic_param) < 0 or float(dynamic_param) > 1: - msgbox(f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...') + msgbox( + f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...' + ) return # Check if save_to end with one of the defines extension. If not add .safetensors. @@ -108,25 +112,18 @@ def gradio_resize_lora_tab(): with gr.Row(): dynamic_method = gr.Dropdown( - choices=['None', - 'sv_ratio', - 'sv_fro', - 'sv_cumulative' - ], + choices=['None', 'sv_ratio', 'sv_fro', 'sv_cumulative'], value='sv_fro', label='Dynamic method', - interactive=True + interactive=True, ) dynamic_param = gr.Textbox( label='Dynamic parameter', value='0.9', interactive=True, - placeholder='Value for the dynamic method selected.' - ) - verbose = gr.Checkbox( - label='Verbose', - value=False + placeholder='Value for the dynamic method selected.', ) + verbose = gr.Checkbox(label='Verbose', value=False) with gr.Row(): save_to = gr.Textbox( label='Save to', @@ -150,7 +147,10 @@ def gradio_resize_lora_tab(): ) device = gr.Dropdown( label='Device', - choices=['cpu', 'cuda',], + choices=[ + 'cpu', + 'cuda', + ], value='cuda', interactive=True, ) diff --git a/library/sampler_gui.py b/library/sampler_gui.py index 7a7734f..ce95313 100644 --- a/library/sampler_gui.py +++ b/library/sampler_gui.py @@ -74,18 +74,18 @@ def run_cmd_sample( sample_prompts, output_dir, ): - output_dir = os.path.join(output_dir, "sample") - + output_dir = os.path.join(output_dir, 'sample') + if not os.path.exists(output_dir): os.makedirs(output_dir) - + run_cmd = '' - + if sample_every_n_epochs == 0 and sample_every_n_steps == 0: return run_cmd # Create the prompt file and get its path - sample_prompts_path = os.path.join(output_dir, "prompt.txt") + sample_prompts_path = os.path.join(output_dir, 'prompt.txt') with open(sample_prompts_path, 'w') as f: f.write(sample_prompts) diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index b34b503..042be2e 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -163,7 +163,10 @@ def gradio_svd_merge_lora_tab(): ) device = gr.Dropdown( label='Device', - choices=['cpu', 'cuda',], + choices=[ + 'cpu', + 'cuda', + ], value='cuda', interactive=True, ) diff --git a/library/train_util.py b/library/train_util.py index 718fe36..7d31182 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1,8 +1,10 @@ # common functions for training import argparse +import ast import importlib import json +import pathlib import re import shutil import time @@ -23,6 +25,7 @@ import random import hashlib import subprocess from io import BytesIO +import toml from tqdm import tqdm import torch @@ -32,10 +35,20 @@ from transformers import CLIPTokenizer import transformers import diffusers from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION -from diffusers import (StableDiffusionPipeline, DDPMScheduler, - EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, - KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler) +from diffusers import ( + StableDiffusionPipeline, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, +) import albumentations as albu import numpy as np from PIL import Image @@ -43,12 +56,12 @@ import cv2 from einops import rearrange from torch import einsum import safetensors.torch - +from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ # checkpointファイル名 EPOCH_STATE_NAME = "{}-{:06d}-state" @@ -64,1086 +77,1241 @@ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] # , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux? -class ImageInfo(): - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: - self.image_key: str = image_key - self.num_repeats: int = num_repeats - self.caption: str = caption - self.is_reg: bool = is_reg - self.absolute_path: str = absolute_path - self.image_size: Tuple[int, int] = None - self.resized_size: Tuple[int, int] = None - self.bucket_reso: Tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_npz_flipped: str = None +class ImageInfo: + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: Tuple[int, int] = None + self.resized_size: Tuple[int, int] = None + self.bucket_reso: Tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None -class BucketManager(): - def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: - self.no_upscale = no_upscale - if max_reso is None: - self.max_reso = None - self.max_area = None - else: - self.max_reso = max_reso - self.max_area = max_reso[0] * max_reso[1] - self.min_size = min_size - self.max_size = max_size - self.reso_steps = reso_steps - - self.resos = [] - self.reso_to_id = {} - self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key - - def add_image(self, reso, image): - bucket_id = self.reso_to_id[reso] - self.buckets[bucket_id].append(image) - - def shuffle(self): - for bucket in self.buckets: - random.shuffle(bucket) - - def sort(self): - # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す - sorted_resos = self.resos.copy() - sorted_resos.sort() - - sorted_buckets = [] - sorted_reso_to_id = {} - for i, reso in enumerate(sorted_resos): - bucket_id = self.reso_to_id[reso] - sorted_buckets.append(self.buckets[bucket_id]) - sorted_reso_to_id[reso] = i - - self.resos = sorted_resos - self.buckets = sorted_buckets - self.reso_to_id = sorted_reso_to_id - - def make_buckets(self): - resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) - self.set_predefined_resos(resos) - - def set_predefined_resos(self, resos): - # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく - self.predefined_resos = resos.copy() - self.predefined_resos_set = set(resos) - self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) - - def add_if_new_reso(self, reso): - if reso not in self.reso_to_id: - bucket_id = len(self.resos) - self.reso_to_id[reso] = bucket_id - self.resos.append(reso) - self.buckets.append([]) - # print(reso, bucket_id, len(self.buckets)) - - def round_to_steps(self, x): - x = int(x + .5) - return x - x % self.reso_steps - - def select_bucket(self, image_width, image_height): - aspect_ratio = image_width / image_height - if not self.no_upscale: - # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する - reso = (image_width, image_height) - if reso in self.predefined_resos_set: - pass - else: - ar_errors = self.predefined_aspect_ratios - aspect_ratio - predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの - reso = self.predefined_resos[predefined_bucket_id] - - ar_reso = reso[0] / reso[1] - if aspect_ratio > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - - resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) - # print("use predef", image_width, image_height, reso, resized_size) - else: - if image_width * image_height > self.max_area: - # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める - resized_width = math.sqrt(self.max_area * aspect_ratio) - resized_height = self.max_area / resized_width - assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" - - # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ - # 元のbucketingと同じロジック - b_width_rounded = self.round_to_steps(resized_width) - b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) - ar_width_rounded = b_width_rounded / b_height_in_wr - - b_height_rounded = self.round_to_steps(resized_height) - b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) - ar_height_rounded = b_width_in_hr / b_height_rounded - - # print(b_width_rounded, b_height_in_wr, ar_width_rounded) - # print(b_width_in_hr, b_height_rounded, ar_height_rounded) - - if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): - resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5)) +class BucketManager: + def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: + self.no_upscale = no_upscale + if max_reso is None: + self.max_reso = None + self.max_area = None else: - resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded) - # print(resized_size) - else: - resized_size = (image_width, image_height) # リサイズは不要 + self.max_reso = max_reso + self.max_area = max_reso[0] * max_reso[1] + self.min_size = min_size + self.max_size = max_size + self.reso_steps = reso_steps - # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) - bucket_width = resized_size[0] - resized_size[0] % self.reso_steps - bucket_height = resized_size[1] - resized_size[1] % self.reso_steps - # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + self.resos = [] + self.reso_to_id = {} + self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key - reso = (bucket_width, bucket_height) + def add_image(self, reso, image): + bucket_id = self.reso_to_id[reso] + self.buckets[bucket_id].append(image) - self.add_if_new_reso(reso) + def shuffle(self): + for bucket in self.buckets: + random.shuffle(bucket) - ar_error = (reso[0] / reso[1]) - aspect_ratio - return reso, resized_size, ar_error + def sort(self): + # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す + sorted_resos = self.resos.copy() + sorted_resos.sort() + + sorted_buckets = [] + sorted_reso_to_id = {} + for i, reso in enumerate(sorted_resos): + bucket_id = self.reso_to_id[reso] + sorted_buckets.append(self.buckets[bucket_id]) + sorted_reso_to_id[reso] = i + + self.resos = sorted_resos + self.buckets = sorted_buckets + self.reso_to_id = sorted_reso_to_id + + def make_buckets(self): + resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) + self.set_predefined_resos(resos) + + def set_predefined_resos(self, resos): + # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく + self.predefined_resos = resos.copy() + self.predefined_resos_set = set(resos) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) + + def add_if_new_reso(self, reso): + if reso not in self.reso_to_id: + bucket_id = len(self.resos) + self.reso_to_id[reso] = bucket_id + self.resos.append(reso) + self.buckets.append([]) + # print(reso, bucket_id, len(self.buckets)) + + def round_to_steps(self, x): + x = int(x + 0.5) + return x - x % self.reso_steps + + def select_bucket(self, image_width, image_height): + aspect_ratio = image_width / image_height + if not self.no_upscale: + # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する + reso = (image_width, image_height) + if reso in self.predefined_resos_set: + pass + else: + ar_errors = self.predefined_aspect_ratios - aspect_ratio + predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの + reso = self.predefined_resos[predefined_bucket_id] + + ar_reso = reso[0] / reso[1] + if aspect_ratio > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + + resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) + # print("use predef", image_width, image_height, reso, resized_size) + else: + if image_width * image_height > self.max_area: + # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める + resized_width = math.sqrt(self.max_area * aspect_ratio) + resized_height = self.max_area / resized_width + assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" + + # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ + # 元のbucketingと同じロジック + b_width_rounded = self.round_to_steps(resized_width) + b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) + ar_width_rounded = b_width_rounded / b_height_in_wr + + b_height_rounded = self.round_to_steps(resized_height) + b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) + ar_height_rounded = b_width_in_hr / b_height_rounded + + # print(b_width_rounded, b_height_in_wr, ar_width_rounded) + # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + + if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): + resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) + else: + resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) + # print(resized_size) + else: + resized_size = (image_width, image_height) # リサイズは不要 + + # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) + bucket_width = resized_size[0] - resized_size[0] % self.reso_steps + bucket_height = resized_size[1] - resized_size[1] % self.reso_steps + # print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height) + + reso = (bucket_width, bucket_height) + + self.add_if_new_reso(reso) + + ar_error = (reso[0] / reso[1]) - aspect_ratio + return reso, resized_size, ar_error class BucketBatchIndex(NamedTuple): - bucket_index: int - bucket_batch_size: int - batch_index: int + bucket_index: int + bucket_batch_size: int + batch_index: int class AugHelper: - def __init__(self): - # prepare all possible augmentators - color_aug_method = albu.OneOf([ - albu.HueSaturationValue(8, 0, 0, p=.5), - albu.RandomGamma((95, 105), p=.5), - ], p=.33) - flip_aug_method = albu.HorizontalFlip(p=0.5) + def __init__(self): + # prepare all possible augmentators + color_aug_method = albu.OneOf( + [ + albu.HueSaturationValue(8, 0, 0, p=0.5), + albu.RandomGamma((95, 105), p=0.5), + ], + p=0.33, + ) + flip_aug_method = albu.HorizontalFlip(p=0.5) - # key: (use_color_aug, use_flip_aug) - self.augmentors = { - (True, True): albu.Compose([ - color_aug_method, - flip_aug_method, - ], p=1.), - (True, False): albu.Compose([ - color_aug_method, - ], p=1.), - (False, True): albu.Compose([ - flip_aug_method, - ], p=1.), - (False, False): None - } + # key: (use_color_aug, use_flip_aug) + self.augmentors = { + (True, True): albu.Compose( + [ + color_aug_method, + flip_aug_method, + ], + p=1.0, + ), + (True, False): albu.Compose( + [ + color_aug_method, + ], + p=1.0, + ), + (False, True): albu.Compose( + [ + flip_aug_method, + ], + p=1.0, + ), + (False, False): None, + } - def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: - return self.augmentors[(use_color_aug, use_flip_aug)] + def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]: + return self.augmentors[(use_color_aug, use_flip_aug)] class BaseSubset: - def __init__(self, image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, keep_tokens: int, color_aug: bool, flip_aug: bool, face_crop_aug_range: Optional[Tuple[float, float]], random_crop: bool, caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float) -> None: - self.image_dir = image_dir - self.num_repeats = num_repeats - self.shuffle_caption = shuffle_caption - self.keep_tokens = keep_tokens - self.color_aug = color_aug - self.flip_aug = flip_aug - self.face_crop_aug_range = face_crop_aug_range - self.random_crop = random_crop - self.caption_dropout_rate = caption_dropout_rate - self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs - self.caption_tag_dropout_rate = caption_tag_dropout_rate + def __init__( + self, + image_dir: Optional[str], + num_repeats: int, + shuffle_caption: bool, + keep_tokens: int, + color_aug: bool, + flip_aug: bool, + face_crop_aug_range: Optional[Tuple[float, float]], + random_crop: bool, + caption_dropout_rate: float, + caption_dropout_every_n_epochs: int, + caption_tag_dropout_rate: float, + ) -> None: + self.image_dir = image_dir + self.num_repeats = num_repeats + self.shuffle_caption = shuffle_caption + self.keep_tokens = keep_tokens + self.color_aug = color_aug + self.flip_aug = flip_aug + self.face_crop_aug_range = face_crop_aug_range + self.random_crop = random_crop + self.caption_dropout_rate = caption_dropout_rate + self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs + self.caption_tag_dropout_rate = caption_tag_dropout_rate - self.img_count = 0 + self.img_count = 0 class DreamBoothSubset(BaseSubset): - def __init__(self, image_dir: str, is_reg: bool, class_tokens: Optional[str], caption_extension: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: - assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + def __init__( + self, + image_dir: str, + is_reg: bool, + class_tokens: Optional[str], + caption_extension: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" - super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, - face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) - self.is_reg = is_reg - self.class_tokens = class_tokens - self.caption_extension = caption_extension + self.is_reg = is_reg + self.class_tokens = class_tokens + self.caption_extension = caption_extension - def __eq__(self, other) -> bool: - if not isinstance(other, DreamBoothSubset): - return NotImplemented - return self.image_dir == other.image_dir + def __eq__(self, other) -> bool: + if not isinstance(other, DreamBoothSubset): + return NotImplemented + return self.image_dir == other.image_dir class FineTuningSubset(BaseSubset): - def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None: - assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" + def __init__( + self, + image_dir, + metadata_file: str, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) -> None: + assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" - super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, - face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) + super().__init__( + image_dir, + num_repeats, + shuffle_caption, + keep_tokens, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + ) - self.metadata_file = metadata_file + self.metadata_file = metadata_file - def __eq__(self, other) -> bool: - if not isinstance(other, FineTuningSubset): - return NotImplemented - return self.metadata_file == other.metadata_file + def __eq__(self, other) -> bool: + if not isinstance(other, FineTuningSubset): + return NotImplemented + return self.metadata_file == other.metadata_file class BaseDataset(torch.utils.data.Dataset): - def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None: - super().__init__() - self.tokenizer = tokenizer - self.max_token_length = max_token_length - # width/height is used when enable_bucket==False - self.width, self.height = (None, None) if resolution is None else resolution - self.debug_dataset = debug_dataset - - self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] - - self.token_padding_disabled = False - self.tag_frequency = {} - - self.enable_bucket = False - self.bucket_manager: BucketManager = None # not initialized - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None - self.bucket_no_upscale = None - self.bucket_info = None # for metadata - - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - - self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ - - # augmentation - self.aug_helper = AugHelper() - - self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) - - self.image_data: Dict[str, ImageInfo] = {} - self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} - - self.replacements = {} - - def set_current_epoch(self, epoch): - self.current_epoch = epoch - self.shuffle_buckets() - - def set_tag_frequency(self, dir_name, captions): - frequency_for_dir = self.tag_frequency.get(dir_name, {}) - self.tag_frequency[dir_name] = frequency_for_dir - for caption in captions: - for tag in caption.split(","): - tag = tag.strip() - if tag: - tag = tag.lower() - frequency = frequency_for_dir.get(tag, 0) - frequency_for_dir[tag] = frequency + 1 - - def disable_token_padding(self): - self.token_padding_disabled = True - - def add_replacement(self, str_from, str_to): - self.replacements[str_from] = str_to - - def process_caption(self, subset: BaseSubset, caption): - # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い - is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate - is_drop_out = is_drop_out or subset.caption_dropout_every_n_epochs > 0 and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 - - if is_drop_out: - caption = "" - else: - if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: - def dropout_tags(tokens): - if subset.caption_tag_dropout_rate <= 0: - return tokens - l = [] - for token in tokens: - if random.random() >= subset.caption_tag_dropout_rate: - l.append(token) - return l - - fixed_tokens = [] - flex_tokens = [t.strip() for t in caption.strip().split(",")] - if subset.keep_tokens > 0: - fixed_tokens = flex_tokens[:subset.keep_tokens] - flex_tokens = flex_tokens[subset.keep_tokens:] - - if subset.shuffle_caption: - random.shuffle(flex_tokens) - - flex_tokens = dropout_tags(flex_tokens) - - caption = ", ".join(fixed_tokens + flex_tokens) - - # textual inversion対応 - for str_from, str_to in self.replacements.items(): - if str_from == "": - # replace all - if type(str_to) == list: - caption = random.choice(str_to) - else: - caption = str_to - else: - caption = caption.replace(str_from, str_to) - - return caption - - def get_input_ids(self, caption): - input_ids = self.tokenizer(caption, padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) - ids_chunk = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - ids_chunk = (input_ids[0].unsqueeze(0), # BOS - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - return input_ids - - def register_image(self, info: ImageInfo, subset: BaseSubset): - self.image_data[info.image_key] = info - self.image_to_subset[info.image_key] = subset - - def make_buckets(self): - ''' - bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - min_size and max_size are ignored when enable_bucket is False - ''' - print("loading image sizes.") - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) - - if self.enable_bucket: - print("make buckets") - else: - print("prepare dataset") - - # bucketを作成し、画像をbucketに振り分ける - if self.enable_bucket: - if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み - self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height), - self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps) - if not self.bucket_no_upscale: - self.bucket_manager.make_buckets() - else: - print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます") - - img_ar_errors = [] - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height) - - # print(image_info.image_key, image_info.bucket_reso) - img_ar_errors.append(abs(ar_error)) - - self.bucket_manager.sort() - else: - self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) - self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) - - for image_info in self.image_data.values(): - for _ in range(image_info.num_repeats): - self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) - - # bucket情報を表示、格納する - if self.enable_bucket: - self.bucket_info = {"buckets": {}} - print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") - for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): - count = len(bucket) - if count > 0: - self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - - img_ar_errors = np.array(img_ar_errors) - mean_img_ar_error = np.mean(np.abs(img_ar_errors)) - self.bucket_info["mean_img_ar_error"] = mean_img_ar_error - print(f"mean ar error (without repeats): {mean_img_ar_error}") - - # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] - for bucket_index, bucket in enumerate(self.bucket_manager.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) - - # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す - #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる - # - # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは - # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう - # # そのためバッチサイズを画像種類までに制限する - # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? - # # TO DO 正則化画像をepochまたがりで利用する仕組み - # num_of_image_types = len(set(bucket)) - # bucket_batch_size = min(self.batch_size, num_of_image_types) - # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) - # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) - # for batch_index in range(batch_count): - # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) - # ↑ここまで - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - self.bucket_manager.shuffle() - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img - - def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): - image_height, image_width = image.shape[0:2] - - if image_width != resized_size[0] or image_height != resized_size[1]: - # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - - image_height, image_width = image.shape[0:2] - if image_width > reso[0]: - trim_size = image_width - reso[0] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("w", trim_size, p) - image = image[:, p:p + reso[0]] - if image_height > reso[1]: - trim_size = image_height - reso[1] - p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) - # print("h", trim_size, p) - image = image[p:p + reso[1]] - - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image - - def is_latent_cacheable(self): - return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) - - def cache_latents(self, vae): - # TODO ここを高速化したい - print("caching latents.") - for info in tqdm(self.image_data.values()): - subset = self.image_to_subset[info.image_key] - - if info.latents_npz is not None: - info.latents = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz(info, True) # might be None - if info.latents_flipped is not None: - info.latents_flipped = torch.FloatTensor(info.latents_flipped) - continue - - image = self.load_image(info.absolute_path) - image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) - - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - if subset.flip_aug: - image = image[:, ::-1].copy() # cannot convert to Tensor without copy - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size - - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = self.load_image(image_path) - - face_cx = face_cy = face_w = face_h = 0 - if subset.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + .5) - nw = int(width * scale + .5) - assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + .5) - face_cy = int(face_cy * scale + .5) - height, width = nh, nw - - # 顔を中心として448*640とかへ切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if subset.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1:p1 + target_size, :] - else: - image = image[:, p1:p1 + target_size] - - return image - - def load_latents_from_npz(self, image_info: ImageInfo, flipped): - npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz - if npz_file is None: - return None - return np.load(npz_file)['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] - bucket_batch_size = self.buckets_indices[index].bucket_batch_size - image_index = self.buckets_indices[index].batch_index * bucket_batch_size - - loss_weights = [] - captions = [] - input_ids_list = [] - latents_list = [] - images = [] - - for image_key in bucket[image_index:image_index + bucket_batch_size]: - image_info = self.image_data[image_key] - subset = self.image_to_subset[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) - - # image/latentsを処理する - if image_info.latents is not None: - latents = image_info.latents if not subset.flip_aug or random.random() < .5 else image_info.latents_flipped - image = None - elif image_info.latents_npz is not None: - latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= .5) - latents = torch.FloatTensor(latents) - image = None - else: - # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert subset.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + def __init__( + self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.max_token_length = max_token_length + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution + self.debug_dataset = debug_dataset + + self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] + + self.token_padding_disabled = False + self.tag_frequency = {} + + self.enable_bucket = False + self.bucket_manager: BucketManager = None # not initialized + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None + self.bucket_no_upscale = None + self.bucket_info = None # for metadata + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) - if aug is not None: - img = aug(image=img)['image'] + self.aug_helper = AugHelper() - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + self.image_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) - images.append(image) - latents_list.append(latents) + self.image_data: Dict[str, ImageInfo] = {} + self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} - caption = self.process_caption(subset, image_info.caption) - captions.append(caption) - if not self.token_padding_disabled: # this option might be omitted in future - input_ids_list.append(self.get_input_ids(caption)) + self.replacements = {} - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) + def set_current_epoch(self, epoch): + self.current_epoch = epoch + self.shuffle_buckets() - if self.token_padding_disabled: - # padding=True means pad in the batch - example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids - else: - # batch processing seems to be good - example['input_ids'] = torch.stack(input_ids_list) + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + tag = tag.strip() + if tag: + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images + def disable_token_padding(self): + self.token_padding_disabled = True - example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None + def add_replacement(self, str_from, str_to): + self.replacements[str_from] = str_to - if self.debug_dataset: - example['image_keys'] = bucket[image_index:image_index + self.batch_size] - example['captions'] = captions - return example + def process_caption(self, subset: BaseSubset, caption): + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate + is_drop_out = ( + is_drop_out + or subset.caption_dropout_every_n_epochs > 0 + and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 + ) + + if is_drop_out: + caption = "" + else: + if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: + + def dropout_tags(tokens): + if subset.caption_tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= subset.caption_tag_dropout_rate: + l.append(token) + return l + + fixed_tokens = [] + flex_tokens = [t.strip() for t in caption.strip().split(",")] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = flex_tokens[subset.keep_tokens :] + + if subset.shuffle_caption: + random.shuffle(flex_tokens) + + flex_tokens = dropout_tags(flex_tokens) + + caption = ", ".join(fixed_tokens + flex_tokens) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) + + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer( + caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" + ).input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range( + 1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2 + ): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo, subset: BaseSubset): + self.image_data[info.image_key] = info + self.image_to_subset[info.image_key] = subset + + def make_buckets(self): + """ + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + """ + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if self.enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketを作成し、画像をbucketに振り分ける + if self.enable_bucket: + if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み + self.bucket_manager = BucketManager( + self.bucket_no_upscale, + (self.width, self.height), + self.min_bucket_reso, + self.max_bucket_reso, + self.bucket_reso_steps, + ) + if not self.bucket_no_upscale: + self.bucket_manager.make_buckets() + else: + print( + "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" + ) + + img_ar_errors = [] + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( + image_width, image_height + ) + + # print(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) + + self.bucket_manager.sort() + else: + self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) + self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + + for image_info in self.image_data.values(): + for _ in range(image_info.num_repeats): + self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + + # bucket情報を表示、格納する + if self.enable_bucket: + self.bucket_info = {"buckets": {}} + print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): + count = len(bucket) + if count > 0: + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} + print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + print(f"mean ar error (without repeats): {mean_img_ar_error}") + + # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる + self.buckets_indices: List(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.bucket_manager.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) + + # ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す + #  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる + # + # # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは + # # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう + # # そのためバッチサイズを画像種類までに制限する + # # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない? + # # TO DO 正則化画像をepochまたがりで利用する仕組み + # num_of_image_types = len(set(bucket)) + # bucket_batch_size = min(self.batch_size, num_of_image_types) + # batch_count = int(math.ceil(len(bucket) / bucket_batch_size)) + # # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count) + # for batch_index in range(batch_count): + # self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) + # ↑ここまで + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + self.bucket_manager.shuffle() + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): + image_height, image_width = image.shape[0:2] + + if image_width != resized_size[0] or image_height != resized_size[1]: + # リサイズする + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + image_height, image_width = image.shape[0:2] + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("w", trim_size, p) + image = image[:, p : p + reso[0]] + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size) + # print("h", trim_size, p) + image = image[p : p + reso[1]] + + assert ( + image.shape[0] == reso[1] and image.shape[1] == reso[0] + ), f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def is_latent_cacheable(self): + return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + + def cache_latents(self, vae): + # TODO ここを高速化したい + print("caching latents.") + for info in tqdm(self.image_data.values()): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + if info.latents_flipped is not None: + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + image = self.load_image(info.absolute_path) + image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + if subset.flip_aug: + image = image[:, ::-1].copy() # cannot convert to Tensor without copy + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, subset: BaseSubset, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if subset.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split("_") + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + 0.5) + nw = int(width * scale + 0.5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + 0.5) + face_cy = int(face_cy * scale + 0.5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if subset.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1 : p1 + target_size, :] + else: + image = image[:, p1 : p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + if npz_file is None: + return None + return np.load(npz_file)["arr_0"] + + def __len__(self): + return self._length + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert ( + subset.random_crop + ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p : p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p : p + self.width] + + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug) + if aug is not None: + img = aug(image=img)["image"] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(subset, image_info.caption) + captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future + input_ids_list.append(self.get_input_ids(caption)) + + example = {} + example["loss_weights"] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example["input_ids"] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example["images"] = images + + example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example["image_keys"] = bucket[image_index : image_index + self.batch_size] + example["captions"] = captions + return example class DreamBoothDataset(BaseDataset): - def __init__(self, subsets: Sequence[DreamBoothSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__( + self, + subsets: Sequence[DreamBoothSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + prior_loss_weight: float, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" - self.batch_size = batch_size - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.latents_cache = None + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.latents_cache = None - self.enable_bucket = enable_bucket - if self.enable_bucket: - assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - self.min_bucket_reso = None - self.max_bucket_reso = None - self.bucket_reso_steps = None # この情報は使われない - self.bucket_no_upscale = False - - def read_caption(img_path, caption_extension): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(subset: DreamBoothSubset): - if not os.path.isdir(subset.image_dir): - print(f"not directory: {subset.image_dir}") - return [], [] - - img_paths = glob_images(subset.image_dir, "*") - print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path, subset.caption_extension) - if cap_for_img is None and subset.class_tokens is None: - print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") - captions.append("") + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale else: - captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None # この情報は使われない + self.bucket_no_upscale = False - self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + def read_caption(img_path, caption_extension): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - return img_paths, captions + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption - print("prepare images.") - num_train_images = 0 - num_reg_images = 0 - reg_infos: List[ImageInfo] = [] - for subset in subsets: - if subset.num_repeats < 1: - print( - f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") - continue + def load_dreambooth_dir(subset: DreamBoothSubset): + if not os.path.isdir(subset.image_dir): + print(f"not directory: {subset.image_dir}") + return [], [] - if subset in self.subsets: - print( - f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") - continue + img_paths = glob_images(subset.image_dir, "*") + print(f"found directory {subset.image_dir} contains {len(img_paths)} image files") - img_paths, captions = load_dreambooth_dir(subset) - if len(img_paths) < 1: - print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") - continue + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path, subset.caption_extension) + if cap_for_img is None and subset.class_tokens is None: + print(f"neither caption file nor class tokens are found. use empty caption for {img_path}") + captions.append("") + else: + captions.append(subset.class_tokens if cap_for_img is None else cap_for_img) - if subset.is_reg: - num_reg_images += subset.num_repeats * len(img_paths) - else: - num_train_images += subset.num_repeats * len(img_paths) + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) - if subset.is_reg: - reg_infos.append(info) + return img_paths, captions + + print("prepare images.") + num_train_images = 0 + num_reg_images = 0 + reg_infos: List[ImageInfo] = [] + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: + print( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue + + img_paths, captions = load_dreambooth_dir(subset) + if len(img_paths) < 1: + print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します") + continue + + if subset.is_reg: + num_reg_images += subset.num_repeats * len(img_paths) + else: + num_train_images += subset.num_repeats * len(img_paths) + + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if subset.is_reg: + reg_infos.append(info) + else: + self.register_image(info, subset) + + subset.img_count = len(img_paths) + self.subsets.append(subset) + + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") else: - self.register_image(info, subset) + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + n = 0 + first_loop = True + while n < num_train_images: + for info in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: + break + first_loop = False - subset.img_count = len(img_paths) - self.subsets.append(subset) - - print(f"{num_train_images} train images with repeating.") - self.num_train_images = num_train_images - - print(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") - else: - # num_repeatsを計算する:どうせ大した数ではないのでループで処理する - n = 0 - first_loop = True - while n < num_train_images: - for info in reg_infos: - if first_loop: - self.register_image(info, subset) - n += info.num_repeats - else: - info.num_repeats += 1 - n += 1 - if n >= num_train_images: - break - first_loop = False - - self.num_reg_images = num_reg_images + self.num_reg_images = num_reg_images class FineTuningDataset(BaseDataset): - def __init__(self, subsets: Sequence[FineTuningSubset], batch_size: int, tokenizer, max_token_length, resolution, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__( + self, + subsets: Sequence[FineTuningSubset], + batch_size: int, + tokenizer, + max_token_length, + resolution, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset, + ) -> None: + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) - self.batch_size = batch_size + self.batch_size = batch_size - self.num_train_images = 0 - self.num_reg_images = 0 + self.num_train_images = 0 + self.num_reg_images = 0 - for subset in subsets: - if subset.num_repeats < 1: - print( - f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}") - continue + for subset in subsets: + if subset.num_repeats < 1: + print( + f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue - if subset in self.subsets: - print( - f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します") - continue + if subset in self.subsets: + print( + f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue - # メタデータを読み込む - if os.path.exists(subset.metadata_file): - print(f"loading existing metadata: {subset.metadata_file}") - with open(subset.metadata_file, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") + # メタデータを読み込む + if os.path.exists(subset.metadata_file): + print(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding="utf-8") as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") - if len(metadata) < 1: - print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") - continue + if len(metadata) < 1: + print(f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します") + continue - tags_list = [] - for image_key, img_md in metadata.items(): - # path情報を作る - if os.path.exists(image_key): - abs_path = image_key + tags_list = [] + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key + else: + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + else: + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(subset.image_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" + abs_path = abs_path[0] + + caption = img_md.get("caption") + tags = img_md.get("tags") + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ", " + tags + tags_list.append(tags) + + if caption is None: + caption = "" + + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info.image_size = img_md.get("train_resolution") + + if not subset.color_aug and not subset.random_crop: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + + self.register_image(image_info, subset) + + self.num_train_images += len(metadata) * subset.num_repeats + + # TODO do not record tag freq when no tag + self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) + subset.img_count = len(metadata) + self.subsets.append(subset) + + # check existence of all npz files + use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) + if use_npz_latents: + flip_aug_in_subset = False + npz_any = False + npz_all = True + + for image_info in self.image_data.values(): + subset = self.image_to_subset[image_info.image_key] + + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if subset.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + flip_aug_in_subset = True + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + use_npz_latents = False + print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + elif not npz_all: + use_npz_latents = False + print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + if flip_aug_in_subset: + print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + # else: + # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + + # check min/max bucket size + sizes = set() + resos = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + resos.add(tuple(image_info.image_size)) + + if sizes is None: + if use_npz_latents: + use_npz_latents = False + print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") + + assert ( + resolution is not None + ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" + + self.enable_bucket = enable_bucket + if self.enable_bucket: + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale else: - npz_path = os.path.join(subset.image_dir, image_key + ".npz") - if os.path.exists(npz_path): - abs_path = npz_path - else: - # わりといい加減だがいい方法が思いつかん - abs_path = glob_images(subset.image_dir, image_key) - assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}" - abs_path = abs_path[0] + if not enable_bucket: + print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + print("using bucket info in metadata / メタデータ内のbucket情報を使います") + self.enable_bucket = True - caption = img_md.get('caption') - tags = img_md.get('tags') - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - tags_list.append(tags) + assert ( + not bucket_no_upscale + ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" - if caption is None: - caption = "" + # bucket情報を初期化しておく、make_bucketsで再作成しない + self.bucket_manager = BucketManager(False, None, None, None, None) + self.bucket_manager.set_predefined_resos(resos) - image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) - image_info.image_size = img_md.get('train_resolution') + # npz情報をきれいにしておく + if not use_npz_latents: + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None - if not subset.color_aug and not subset.random_crop: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + ".npz" - self.register_image(image_info, subset) + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + "_flip.npz" + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip - self.num_train_images += len(metadata) * subset.num_repeats + # image_key is relative path + npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") + npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") - # TODO do not record tag freq when no tag - self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) - subset.img_count = len(metadata) - self.subsets.append(subset) + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None - # check existence of all npz files - use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) - if use_npz_latents: - flip_aug_in_subset = False - npz_any = False - npz_all = True - - for image_info in self.image_data.values(): - subset = self.image_to_subset[image_info.image_key] - - has_npz = image_info.latents_npz is not None - npz_any = npz_any or has_npz - - if subset.flip_aug: - has_npz = has_npz and image_info.latents_npz_flipped is not None - flip_aug_in_subset = True - npz_all = npz_all and has_npz - - if npz_any and not npz_all: - break - - if not npz_any: - use_npz_latents = False - print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") - elif not npz_all: - use_npz_latents = False - print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") - if flip_aug_in_subset: - print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") - # else: - # print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") - - # check min/max bucket size - sizes = set() - resos = set() - for image_info in self.image_data.values(): - if image_info.image_size is None: - sizes = None # not calculated - break - sizes.add(image_info.image_size[0]) - sizes.add(image_info.image_size[1]) - resos.add(tuple(image_info.image_size)) - - if sizes is None: - if use_npz_latents: - use_npz_latents = False - print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します") - - assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" - - self.enable_bucket = enable_bucket - if self.enable_bucket: - self.min_bucket_reso = min_bucket_reso - self.max_bucket_reso = max_bucket_reso - self.bucket_reso_steps = bucket_reso_steps - self.bucket_no_upscale = bucket_no_upscale - else: - if not enable_bucket: - print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") - print("using bucket info in metadata / メタデータ内のbucket情報を使います") - self.enable_bucket = True - - assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" - - # bucket情報を初期化しておく、make_bucketsで再作成しない - self.bucket_manager = BucketManager(False, None, None, None, None) - self.bucket_manager.set_predefined_resos(resos) - - # npz情報をきれいにしておく - if not use_npz_latents: - for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None - - def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): - base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + '.npz' - - if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + '_flip.npz' - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip - - # image_key is relative path - npz_file_norm = os.path.join(subset.image_dir, image_key + '.npz') - npz_file_flip = os.path.join(subset.image_dir, image_key + '_flip.npz') - - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None - - return npz_file_norm, npz_file_flip + return npz_file_norm, npz_file_flip # behave as Dataset mock class DatasetGroup(torch.utils.data.ConcatDataset): - def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): - self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] + def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): + self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] - super().__init__(datasets) + super().__init__(datasets) - self.image_data = {} - self.num_train_images = 0 - self.num_reg_images = 0 + self.image_data = {} + self.num_train_images = 0 + self.num_reg_images = 0 - # simply concat together - # TODO: handling image_data key duplication among dataset - # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. - for dataset in datasets: - self.image_data.update(dataset.image_data) - self.num_train_images += dataset.num_train_images - self.num_reg_images += dataset.num_reg_images + # simply concat together + # TODO: handling image_data key duplication among dataset + # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. + for dataset in datasets: + self.image_data.update(dataset.image_data) + self.num_train_images += dataset.num_train_images + self.num_reg_images += dataset.num_reg_images - def add_replacement(self, str_from, str_to): - for dataset in self.datasets: - dataset.add_replacement(str_from, str_to) + def add_replacement(self, str_from, str_to): + for dataset in self.datasets: + dataset.add_replacement(str_from, str_to) - # def make_buckets(self): - # for dataset in self.datasets: - # dataset.make_buckets() + # def make_buckets(self): + # for dataset in self.datasets: + # dataset.make_buckets() - def cache_latents(self, vae): - for i, dataset in enumerate(self.datasets): - print(f"[Dataset {i}]") - dataset.cache_latents(vae) + def cache_latents(self, vae): + for i, dataset in enumerate(self.datasets): + print(f"[Dataset {i}]") + dataset.cache_latents(vae) - def is_latent_cacheable(self) -> bool: - return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + def is_latent_cacheable(self) -> bool: + return all([dataset.is_latent_cacheable() for dataset in self.datasets]) - def set_current_epoch(self, epoch): - for dataset in self.datasets: - dataset.set_current_epoch(epoch) + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) - def disable_token_padding(self): - for dataset in self.datasets: - dataset.disable_token_padding() + def disable_token_padding(self): + for dataset in self.datasets: + dataset.disable_token_padding() def debug_dataset(train_dataset, show_input_ids=False): - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + print("Escape for exit. / Escキーで中断、終了します") - train_dataset.set_current_epoch(1) - k = 0 - indices = list(range(len(train_dataset))) - random.shuffle(indices) - for i, idx in enumerate(indices): - example = train_dataset[idx] - if example['latents'] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') - if show_input_ids: - print(f"input ids: {iid}") - if example['images'] is not None: - im = example['images'][j] - print(f"image size: {im.size()}") - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - if os.name == 'nt': # only windows - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27 or (example['images'] is None and i >= 8): - break + train_dataset.set_current_epoch(1) + k = 0 + indices = list(range(len(train_dataset))) + random.shuffle(indices) + for i, idx in enumerate(indices): + example = train_dataset[idx] + if example["latents"] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate( + zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + ): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example["images"] is not None: + im = example["images"][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break + if k == 27 or (example["images"] is None and i >= 8): + break def glob_images(directory, base="*"): - img_paths = [] - for ext in IMAGE_EXTENSIONS: - if base == '*': - img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) - else: - img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - # img_paths = list(set(img_paths)) # 重複を排除 - # img_paths.sort() - return img_paths + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == "*": + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + # img_paths = list(set(img_paths)) # 重複を排除 + # img_paths.sort() + return img_paths def glob_images_pathlib(dir_path, recursive): - image_paths = [] - if recursive: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.rglob('*' + ext)) - else: - for ext in IMAGE_EXTENSIONS: - image_paths += list(dir_path.glob('*' + ext)) - # image_paths = list(set(image_paths)) # 重複を排除 - # image_paths.sort() - return image_paths + image_paths = [] + if recursive: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.rglob("*" + ext)) + else: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.glob("*" + ext)) + # image_paths = list(set(image_paths)) # 重複を排除 + # image_paths.sort() + return image_paths + # endregion @@ -1165,86 +1333,97 @@ EPSILON = 1e-6 def exists(val): - return val is not None + return val is not None def default(val, d): - return val if exists(val) else d + return val if exists(val) else d def model_hash(filename): - """Old model hash used by stable-diffusion-webui""" - try: - with open(filename, "rb") as file: - m = hashlib.sha256() + """Old model hash used by stable-diffusion-webui""" + try: + with open(filename, "rb") as file: + m = hashlib.sha256() - file.seek(0x100000) - m.update(file.read(0x10000)) - return m.hexdigest()[0:8] - except FileNotFoundError: - return 'NOFILE' + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" def calculate_sha256(filename): - """New model hash used by stable-diffusion-webui""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 + """New model hash used by stable-diffusion-webui""" + try: + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 - with open(filename, "rb") as f: - for chunk in iter(lambda: f.read(blksize), b""): - hash_sha256.update(chunk) + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) - return hash_sha256.hexdigest() + return hash_sha256.hexdigest() + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" def precalculate_safetensors_hashes(tensors, metadata): - """Precalculate the model hashes needed by sd-webui-additional-networks to - save time on indexing the model later.""" + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" - # Because writing user metadata to the file can change the result of - # sd_models.model_hash(), only retain the training metadata for purposes of - # calculating the hash, as they are meant to be immutable - metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} - bytes = safetensors.torch.save(tensors, metadata) - b = BytesIO(bytes) + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) - model_hash = addnet_hash_safetensors(b) - legacy_hash = addnet_hash_legacy(b) - return model_hash, legacy_hash + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash def addnet_hash_legacy(b): - """Old model hash used by sd-webui-additional-networks for .safetensors format files""" - m = hashlib.sha256() + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() - b.seek(0x100000) - m.update(b.read(0x10000)) - return m.hexdigest()[0:8] + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] def addnet_hash_safetensors(b): - """New model hash used by sd-webui-additional-networks for .safetensors format files""" - hash_sha256 = hashlib.sha256() - blksize = 1024 * 1024 + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 - b.seek(0) - header = b.read(8) - n = int.from_bytes(header, "little") + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") - offset = n + 8 - b.seek(offset) - for chunk in iter(lambda: b.read(blksize), b""): - hash_sha256.update(chunk) + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) - return hash_sha256.hexdigest() + return hash_sha256.hexdigest() def get_git_revision_hash() -> str: - try: - return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip() - except: - return "(unknown)" + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() + except: + return "(unknown)" # flash attention forwards and backwards @@ -1253,422 +1432,688 @@ def get_git_revision_hash() -> str: class FlashAttentionFunction(torch.autograd.function.Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - scale = (q.shape[-1] ** -0.5) + scale = q.shape[-1] ** -0.5 - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - return o + return o - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors - device = q.device + device = q.device - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - exp_attn_weights = torch.exp(attn_weights - mc) + exp_attn_weights = torch.exp(attn_weights - mc) - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) - p = exp_attn_weights / lc + p = exp_attn_weights / lc - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention (not xformers)") - flash_func = FlashAttentionFunction + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 - h = self.heads - q = self.to_q(x) + h = self.heads + q = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) + context = context if context is not None else x + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + diffusers.models.attention.CrossAttention.forward = forward_flash_attn def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + context = default(context, x) + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, 'b n h d -> b n (h d)', h=h) + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - diffusers.models.attention.CrossAttention.forward = forward_xformers # endregion # region arguments + def add_sd_models_arguments(parser: argparse.ArgumentParser): - # for pretrained models - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - parser.add_argument("--tokenizer_cache_dir", type=str, default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)") + # for pretrained models + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) def add_optimizer_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor") + parser.add_argument( + "--optimizer_type", + type=str, + default="", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor", + ) - # backward compatibility - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--use_lion_optimizer", action="store_true", - help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)") + # backward compatibility + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)", + ) + parser.add_argument( + "--use_lion_optimizer", + action="store_true", + help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)", + ) - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument("--max_grad_norm", default=1.0, type=float, - help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない") + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない" + ) - parser.add_argument("--optimizer_args", type=str, default=None, nargs='*', - help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")") + parser.add_argument( + "--optimizer_args", + type=str, + default=None, + nargs="*", + help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', + ) - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") - parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1, - help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数") - parser.add_argument("--lr_scheduler_power", type=float, default=1, - help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power") + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") + parser.add_argument( + "--lr_scheduler_args", + type=str, + default=None, + nargs="*", + help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")', + ) + + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + ) + parser.add_argument( + "--lr_scheduler_num_cycles", + type=int, + default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数", + ) + parser.add_argument( + "--lr_scheduler_power", + type=float, + default=1, + help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--output_name", type=str, default=None, - help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_n_epoch_ratio", type=int, default=None, - help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)") - parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_last_n_epochs_state", type=int, default=None, - help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving / 保存時に精度を変更して保存する", + ) + parser.add_argument( + "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" + ) + parser.add_argument( + "--save_n_epoch_ratio", + type=int, + default=None, + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)", + ) + parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") + parser.add_argument( + "--save_last_n_epochs_state", + type=int, + default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)", + ) + parser.add_argument( + "--save_state", + action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する", + ) + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") - parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], - help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument( + "--max_token_length", + type=int, + default=None, + choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)", + ) + parser.add_argument( + "--mem_eff_attn", + action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", + ) + parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + ) - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--max_train_epochs", type=int, default=None, - help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") - parser.add_argument("--max_data_loader_n_workers", type=int, default=8, - help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") - parser.add_argument("--persistent_data_loader_workers", action="store_true", - help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--noise_offset", type=float, default=None, - help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)") - parser.add_argument("--lowram", action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)") + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument( + "--max_train_epochs", + type=int, + default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)", + ) + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)", + ) + parser.add_argument( + "--persistent_data_loader_workers", + action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument( + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", + ) + parser.add_argument( + "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + ) + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)", + ) + parser.add_argument( + "--logging_dir", + type=str, + default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", + ) + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument( + "--noise_offset", + type=float, + default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", + ) + parser.add_argument( + "--lowram", + action="store_true", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) - parser.add_argument("--sample_every_n_steps", type=int, default=None, - help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する") - parser.add_argument("--sample_every_n_epochs", type=int, default=None, - help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)") - parser.add_argument("--sample_prompts", type=str, default=None, - help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル") - parser.add_argument('--sample_sampler', type=str, default='ddim', - choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver', - 'dpmsolver++', 'dpmsingle', - 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'], - help=f'sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類') + parser.add_argument( + "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" + ) + parser.add_argument( + "--sample_every_n_epochs", + type=int, + default=None, + help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", + ) + parser.add_argument( + "--sample_prompts", type=str, default=None, help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル" + ) + parser.add_argument( + "--sample_sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", + ) - if support_dreambooth: - # DreamBooth training - parser.add_argument("--prior_loss_weight", type=float, default=1.0, - help="loss weight for regularization images / 正則化画像のlossの重み") + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す", + ) + parser.add_argument( + "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" + ) + + if support_dreambooth: + # DreamBooth training + parser.add_argument( + "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" + ) def verify_training_args(args: argparse.Namespace): - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") -def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool): - # dataset common - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") - parser.add_argument("--keep_tokens", type=int, default=0, - help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)") - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") - parser.add_argument("--face_crop_aug_range", type=str, default=None, - help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") - parser.add_argument("--random_crop", action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--resolution", type=str, default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") - parser.add_argument("--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") - parser.add_argument("--enable_bucket", action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") - parser.add_argument("--bucket_reso_steps", type=int, default=64, - help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します") - parser.add_argument("--bucket_no_upscale", action="store_true", - help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") +def add_dataset_arguments( + parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool +): + # dataset common + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする" + ) + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)", + ) + parser.add_argument( + "--keep_tokens", + type=int, + default=0, + help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", + ) + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument( + "--face_crop_aug_range", + type=str, + default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)", + ) + parser.add_argument( + "--random_crop", + action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", + ) + parser.add_argument( + "--debug_dataset", action="store_true", help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)" + ) + parser.add_argument( + "--resolution", + type=str, + default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)", + ) + parser.add_argument( + "--cache_latents", + action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)", + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" + ) + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument( + "--bucket_reso_steps", + type=int, + default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", + ) + parser.add_argument( + "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + ) - if support_caption_dropout: - # Textual Inversion はcaptionのdropoutをsupportしない - # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに - parser.add_argument("--caption_dropout_rate", type=float, default=0.0, - help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合") - parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=0, - help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする") - parser.add_argument("--caption_tag_dropout_rate", type=float, default=0.0, - help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合") + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument( + "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合" + ) + parser.add_argument( + "--caption_dropout_every_n_epochs", + type=int, + default=0, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする", + ) + parser.add_argument( + "--caption_tag_dropout_rate", + type=float, + default=0.0, + help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合", + ) - if support_dreambooth: - # DreamBooth dataset - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + if support_dreambooth: + # DreamBooth dataset + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - if support_caption: - # caption dataset - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") - parser.add_argument("--dataset_repeats", type=int, default=1, - help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") + if support_caption: + # caption dataset + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument( + "--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数" + ) def add_sd_saving_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") + parser.add_argument( + "--save_model_as", + type=str, + default=None, + choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)", + ) + parser.add_argument( + "--use_safetensors", + action="store_true", + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)", + ) + + +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if args.output_config: + # check if config file exists + if os.path.exists(config_path): + print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + exit(1) + + # convert args to dictionary + args_dict = vars(args) + + # remove unnecessary keys + for key in ["config_file", "output_config"]: + if key in args_dict: + del args_dict[key] + + # get default args from parser + default_args = vars(parser.parse_args([])) + + # remove default values: cannot use args_dict.items directly because it will be changed during iteration + for key, value in list(args_dict.items()): + if key in default_args and value == default_args[key]: + del args_dict[key] + + # convert Path to str in dictionary + for key, value in args_dict.items(): + if isinstance(value, pathlib.Path): + args_dict[key] = str(value) + + # convert to toml and output to file + with open(config_path, "w") as f: + toml.dump(args_dict, f) + + print(f"Saved config file / 設定ファイルを保存しました: {config_path}") + exit(0) + + if not os.path.exists(config_path): + print(f"{config_path} not found.") + exit(1) + + print(f"Loading settings from {config_path}...") + with open(config_path, "r") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + print(args.config_file) + + return args + # endregion @@ -1676,178 +2121,191 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor" - optimizer_type = args.optimizer_type - if args.use_8bit_adam: - assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています" - assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています" - optimizer_type = "AdamW8bit" + optimizer_type = args.optimizer_type + if args.use_8bit_adam: + assert ( + not args.use_lion_optimizer + ), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています" + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "AdamW8bit" - elif args.use_lion_optimizer: - assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています" - optimizer_type = "Lion" + elif args.use_lion_optimizer: + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "Lion" - if optimizer_type is None or optimizer_type == "": - optimizer_type = "AdamW" - optimizer_type = optimizer_type.lower() + if optimizer_type is None or optimizer_type == "": + optimizer_type = "AdamW" + optimizer_type = optimizer_type.lower() - # 引数を分解する:boolとfloat、tupleのみ対応 - optimizer_kwargs = {} - if args.optimizer_args is not None and len(args.optimizer_args) > 0: - for arg in args.optimizer_args: - key, value = arg.split('=') + # 引数を分解する + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split("=") + value = ast.literal_eval(value) - value = value.split(",") - for i in range(len(value)): - if value[i].lower() == "true" or value[i].lower() == "false": - value[i] = (value[i].lower() == "true") + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.float(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = tuple(value) + + optimizer_kwargs[key] = value + # print("optkwargs:", optimizer_kwargs) + + lr = args.learning_rate + + if optimizer_type == "AdamW8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print( + f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion".lower(): + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + print(f"use Lion optimizer | {optimizer_kwargs}") + optimizer_class = lion_pytorch.Lion + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov".lower(): + print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = torch.optim.SGD + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "DAdaptation".lower(): + try: + import dadaptation + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + print( + f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" + ) + print("recommend option: lr=1.0 / 推奨は1.0です") + if lr_count > 1: + print( + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + ) + + optimizer_class = dadaptation.DAdaptAdam + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # 引数を確認して適宜補正する + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") + optimizer_kwargs["relative_step"] = True + print(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + print(f"relative_step is true / relative_stepがtrueです") + if lr != 0.0: + print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + args.learning_rate = None + + # trainable_paramsがgroupだった時の処理:lrを削除する + if type(trainable_params) == list and type(trainable_params[0]) == dict: + has_group_lr = False + for group in trainable_params: + p = group.pop("lr", None) + has_group_lr = has_group_lr or (p is not None) + + if has_group_lr: + # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない + print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + args.unet_lr = None + args.text_encoder_lr = None + + if args.lr_scheduler != "adafactor": + print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど + + lr = None else: - value[i] = float(value[i]) - if len(value) == 1: - value = value[0] - else: - value = tuple(value) + if args.max_grad_norm != 0.0: + print( + f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" + ) + if args.lr_scheduler != "constant_with_warmup": + print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") - optimizer_kwargs[key] = value - # print("optkwargs:", optimizer_kwargs) + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - lr = args.learning_rate + elif optimizer_type == "AdamW".lower(): + print(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - if optimizer_type == "AdamW8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") - optimizer_class = bnb.optim.AdamW8bit - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov8bit".lower(): - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します") - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = bnb.optim.SGD8bit - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type == "Lion".lower(): - try: - import lion_pytorch - except ImportError: - raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") - print(f"use Lion optimizer | {optimizer_kwargs}") - optimizer_class = lion_pytorch.Lion - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "SGDNesterov".lower(): - print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") - if "momentum" not in optimizer_kwargs: - print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します") - optimizer_kwargs["momentum"] = 0.9 - - optimizer_class = torch.optim.SGD - optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - - elif optimizer_type == "DAdaptation".lower(): - try: - import dadaptation - except ImportError: - raise ImportError("No dadaptation / dadaptation がインストールされていないようです") - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - - actual_lr = lr - lr_count = 1 - if type(trainable_params) == list and type(trainable_params[0]) == dict: - lrs = set() - actual_lr = trainable_params[0].get("lr", actual_lr) - for group in trainable_params: - lrs.add(group.get("lr", actual_lr)) - lr_count = len(lrs) - - if actual_lr <= 0.1: - print( - f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}') - print('recommend option: lr=1.0 / 推奨は1.0です') - if lr_count > 1: - print( - f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}") - - optimizer_class = dadaptation.DAdaptAdam - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "Adafactor".lower(): - # 引数を確認して適宜補正する - if "relative_step" not in optimizer_kwargs: - optimizer_kwargs["relative_step"] = True # default - if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): - print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします") - optimizer_kwargs["relative_step"] = True - print(f"use Adafactor optimizer | {optimizer_kwargs}") - - if optimizer_kwargs["relative_step"]: - print(f"relative_step is true / relative_stepがtrueです") - if lr != 0.0: - print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") - args.learning_rate = None - - # trainable_paramsがgroupだった時の処理:lrを削除する - if type(trainable_params) == list and type(trainable_params[0]) == dict: - has_group_lr = False - for group in trainable_params: - p = group.pop("lr", None) - has_group_lr = has_group_lr or (p is not None) - - if has_group_lr: - # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない - print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") - args.unet_lr = None - args.text_encoder_lr = None - - if args.lr_scheduler != "adafactor": - print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") - args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど - - lr = None else: - if args.max_grad_norm != 0.0: - print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません") - if args.lr_scheduler != "constant_with_warmup": - print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") - if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: - print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + # 任意のoptimizerを使う + optimizer_type = args.optimizer_type # lowerでないやつ(微妙) + print(f"use {optimizer_type} | {optimizer_kwargs}") + if "." not in optimizer_type: + optimizer_module = torch.optim + else: + values = optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + optimizer_type = values[-1] - optimizer_class = transformers.optimization.Adafactor - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + optimizer_class = getattr(optimizer_module, optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - elif optimizer_type == "AdamW".lower(): - print(f"use AdamW optimizer | {optimizer_kwargs}") - optimizer_class = torch.optim.AdamW - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ + optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - else: - # 任意のoptimizerを使う - optimizer_type = args.optimizer_type # lowerでないやつ(微妙) - print(f"use {optimizer_type} | {optimizer_kwargs}") - if "." not in optimizer_type: - optimizer_module = torch.optim - else: - values = optimizer_type.split(".") - optimizer_module = importlib.import_module(".".join(values[:-1])) - optimizer_type = values[-1] - - optimizer_class = getattr(optimizer_module, optimizer_type) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ - optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) - - return optimizer_name, optimizer_args, optimizer + return optimizer_name, optimizer_args, optimizer # Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler @@ -1856,540 +2314,599 @@ def get_optimizer(args, trainable_params): # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts -def get_scheduler_fix( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - num_cycles: int = 1, - power: float = 1.0, -): - """ - Unified API to get any scheduler from its name. - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_cycles (`int`, *optional*): - The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. - power (`float`, *optional*, defaults to 1.0): - Power factor. See `POLYNOMIAL` scheduler - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - """ - if name.startswith("adafactor"): - assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" - initial_lr = float(name.split(':')[1]) - # print("adafactor scheduler init lr", initial_lr) - return transformers.optimization.AdafactorSchedule(optimizer, initial_lr) +def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): + """ + Unified API to get any scheduler from its name. + """ + name = args.lr_scheduler + num_warmup_steps = args.lr_warmup_steps + num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps + num_cycles = args.lr_scheduler_num_cycles + power = args.lr_scheduler_power - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) + lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs + if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: + for arg in args.lr_scheduler_args: + key, value = arg.split("=") - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + value = ast.literal_eval(value) + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.literal_eval(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = list(value) # some may use list? - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + lr_scheduler_kwargs[key] = value - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + # using any lr_scheduler from other library + if args.lr_scheduler_type: + lr_scheduler_type = args.lr_scheduler_type + print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + if "." not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return lr_scheduler - if name == SchedulerType.COSINE_WITH_RESTARTS: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles - ) + if name.startswith("adafactor"): + assert ( + type(optimizer) == transformers.optimization.Adafactor + ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" + initial_lr = float(name.split(":")[1]) + # print("adafactor scheduler init lr", initial_lr) + return transformers.optimization.AdafactorSchedule(optimizer, initial_lr) - if name == SchedulerType.POLYNOMIAL: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power - ) + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + if name == SchedulerType.CONSTANT: + return schedule_func(optimizer) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): - # backward compatibility - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - args.caption_extention = None + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None - # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" - if args.resolution is not None: - args.resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(args.resolution) == 1: - args.resolution = (args.resolution[0], args.resolution[0]) - assert len(args.resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" + if args.resolution is not None: + args.resolution = tuple([int(r) for r in args.resolution.split(",")]) + if len(args.resolution) == 1: + args.resolution = (args.resolution[0], args.resolution[0]) + assert ( + len(args.resolution) == 2 + ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - if args.face_crop_aug_range is not None: - args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) - assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \ - f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - args.face_crop_aug_range = None + if args.face_crop_aug_range is not None: + args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) + assert ( + len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1] + ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + args.face_crop_aug_range = None - if support_metadata: - if args.in_json is not None and (args.color_aug or args.random_crop): - print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます") + if support_metadata: + if args.in_json is not None and (args.color_aug or args.random_crop): + print( + f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" + ) def load_tokenizer(args: argparse.Namespace): - print("prepare tokenizer") - original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH + print("prepare tokenizer") + original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH - tokenizer: CLIPTokenizer = None - if args.tokenizer_cache_dir: - local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_')) - if os.path.exists(local_tokenizer_path): - print(f"load tokenizer from cache: {local_tokenizer_path}") - tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + print(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2 - if tokenizer is None: - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(original_path) + if tokenizer is None: + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(original_path) - if hasattr(args, "max_token_length") and args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") + if hasattr(args, "max_token_length") and args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") - if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): - print(f"save Tokenizer to cache: {local_tokenizer_path}") - tokenizer.save_pretrained(local_tokenizer_path) + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + print(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) - return tokenizer + return tokenizer def prepare_accelerator(args: argparse.Namespace): - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - log_with=log_with, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=log_with, + logging_dir=logging_dir, + ) - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) - return accelerator, unwrap_model + return accelerator, unwrap_model def prepare_dtype(args: argparse.Namespace): - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 - return weight_dtype, save_dtype + return weight_dtype, save_dtype def load_target_model(args: argparse.Namespace, weight_dtype): - name_or_path = args.pretrained_model_name_or_path - name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path - load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) - else: - print("load Diffusers pretrained models") - try: - pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) - except EnvironmentError as ex: - print( - f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}") - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe + name_or_path = args.pretrained_model_name_or_path + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path) + else: + print("load Diffusers pretrained models") + try: + pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) + except EnvironmentError as ex: + print( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" + ) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") - return text_encoder, vae, unet, load_stable_diffusion_format + return text_encoder, vae, unet, load_stable_diffusion_format def patch_accelerator_for_fp16_training(accelerator): - org_unscale_grads = accelerator.scaler._unscale_grads_ + org_unscale_grads = accelerator.scaler._unscale_grads_ - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): - # with no_token_padding, the length is not max length, return result immediately - if input_ids.size()[-1] != tokenizer.model_max_length: - return text_encoder(input_ids)[0] + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] - b_size = input_ids.size()[0] - input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - if weight_dtype is not None: - # this is required for additional network training - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - return encoder_hidden_states + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + return encoder_hidden_states def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): - model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name - ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") - return model_name, ckpt_name + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): - saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs - if saving: - os.makedirs(args.output_dir, exist_ok=True) - save_func() + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + if saving: + os.makedirs(args.output_dir, exist_ok=True) + save_func() - if args.save_last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs - remove_old_func(remove_epoch_no) - return saving + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving -def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): - epoch_no = epoch + 1 - model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) +def save_sd_model_on_epoch_end( + args: argparse.Namespace, + accelerator, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder, + unet, + vae, +): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) - if save_stable_diffusion_format: - def save_sd(): - ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"saving checkpoint: {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_path, epoch_no, global_step, save_dtype, vae) + if save_stable_diffusion_format: - def remove_sd(old_epoch_no): - _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae + ) - save_func = save_sd - remove_old_func = remove_sd - else: - def save_du(): - out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) - print(f"saving model: {out_dir}") - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, - src_path, vae=vae, use_safetensors=use_safetensors) + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) - def remove_du(old_epoch_no): - out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) - if os.path.exists(out_dir_old): - print(f"removing old model: {out_dir_old}") - shutil.rmtree(out_dir_old) + save_func = save_sd + remove_old_func = remove_sd + else: - save_func = save_du - remove_old_func = remove_du + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) - saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) - if saving and args.save_state: - save_state_on_epoch_end(args, accelerator, model_name, epoch_no) + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") + shutil.rmtree(out_dir_old) + + save_func = save_du + remove_old_func = remove_du + + saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if saving and args.save_state: + save_state_on_epoch_end(args, accelerator, model_name, epoch_no) def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) - last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs - if last_n_epochs is not None: - remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs - state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) - if os.path.exists(state_dir_old): - print(f"removing old state: {state_dir_old}") - shutil.rmtree(state_dir_old) + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) -def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder, + unet, + vae, +): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - if save_stable_diffusion_format: - os.makedirs(args.output_dir, exist_ok=True) + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) - ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") - ckpt_file = os.path.join(args.output_dir, ckpt_name) + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_path, epoch, global_step, save_dtype, vae) - else: - out_dir = os.path.join(args.output_dir, model_name) - os.makedirs(out_dir, exist_ok=True) + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae + ) + else: + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) - print(f"save trained model as Diffusers to {out_dir}") - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, - src_path, vae=vae, use_safetensors=use_safetensors) + print(f"save trained model as Diffusers to {out_dir}") + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) def save_state_on_train_end(args: argparse.Namespace, accelerator): - print("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) # scheduler: SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = 'scaled_linear' +SCHEDLER_SCHEDULE = "scaled_linear" -def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None): - """ - 生成に使っている Diffusers の Pipeline がデフォルトなので、プロンプトの重みづけには対応していない - clip skipは対応した - """ - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: - return - else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch - return +def sample_images( + accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None +): + """ + StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + """ + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: + return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return - print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): - print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") - return + print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") + if not os.path.isfile(args.sample_prompts): + print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") + return - org_vae_device = vae.device # CPUにいるはず - vae.to(device) + org_vae_device = vae.device # CPUにいるはず + vae.to(device) - # clip skip 対応のための wrapper を作る - if args.clip_skip is None: - text_encoder_or_wrapper = text_encoder - else: - class Wrapper(): - def __init__(self, tenc) -> None: - self.tenc = tenc - self.config = {} - super().__init__() + # read prompts + with open(args.sample_prompts, "rt", encoding="utf-8") as f: + prompts = f.readlines() - def __call__(self, input_ids, attention_mask): - enc_out = self.tenc(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = self.tenc.text_model.final_layer_norm(encoder_hidden_states) - pooled_output = enc_out['pooler_output'] - return encoder_hidden_states, pooled_output # 1st output is only used + # schedulerを用意する + sched_init_args = {} + if args.sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif args.sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sample_sampler + elif args.sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif args.sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler - text_encoder_or_wrapper = Wrapper(text_encoder) + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" - # read prompts - with open(args.sample_prompts, 'rt', encoding='utf-8') as f: - prompts = f.readlines() + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) - # schedulerを用意する - sched_init_args = {} - if args.sample_sampler == "ddim": - scheduler_cls = DDIMScheduler - elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - elif args.sample_sampler == "pndm": - scheduler_cls = PNDMScheduler - elif args.sample_sampler == 'lms' or args.sample_sampler == 'k_lms': - scheduler_cls = LMSDiscreteScheduler - elif args.sample_sampler == 'euler' or args.sample_sampler == 'k_euler': - scheduler_cls = EulerDiscreteScheduler - elif args.sample_sampler == 'euler_a' or args.sample_sampler == 'k_euler_a': - scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args['algorithm_type'] = args.sample_sampler - elif args.sample_sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - elif args.sample_sampler == 'dpm_2' or args.sample_sampler == 'k_dpm_2': - scheduler_cls = KDPM2DiscreteScheduler - elif args.sample_sampler == 'dpm_2_a' or args.sample_sampler == 'k_dpm_2_a': - scheduler_cls = KDPM2AncestralDiscreteScheduler - else: - scheduler_cls = DDIMScheduler + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + scheduler.config.clip_sample = True - if args.v_parameterization: - sched_init_args['prediction_type'] = 'v_prediction' + pipeline = StableDiffusionLongPromptWeightingPipeline( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + clip_skip=args.clip_skip, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + pipeline.to(device) - scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args) + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - scheduler.config.clip_sample = True + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() - pipeline = StableDiffusionPipeline(text_encoder=text_encoder_or_wrapper, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False) - pipeline.to(device) + with torch.no_grad(): + with accelerator.autocast(): + for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue + prompt = prompt.strip() + if len(prompt) == 0 or prompt[0] == "#": + continue - save_dir = args.output_dir + "/sample" - os.makedirs(save_dir, exist_ok=True) + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue - rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue - with torch.no_grad(): - with accelerator.autocast(): - for i, prompt in enumerate(prompts): - if not accelerator.is_main_process: - continue - prompt = prompt.strip() - if len(prompt) == 0 or prompt[0] == '#': - continue + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue - # subset of gen_img_diffusers - prompt_args = prompt.split(' --') - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - for parg in prompt_args: - try: - m = re.match(r'w (\d+)', parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue - m = re.match(r'h (\d+)', parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue - m = re.match(r'd (\d+)', parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue - m = re.match(r's (\d+)', parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) - m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) - m = re.match(r'n (.+)', parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + image = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + ).images[0] - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + ) - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + image.save(os.path.join(save_dir, img_filename)) - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0] + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() - ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + torch.set_rng_state(rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + vae.to(org_vae_device) - image.save(os.path.join(save_dir, img_filename)) - - # clear pipeline and cache to reduce vram usage - del pipeline - torch.cuda.empty_cache() - - torch.set_rng_state(rng_state) - torch.cuda.set_rng_state(cuda_rng_state) - vae.to(org_vae_device) # endregion @@ -2397,24 +2914,24 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v class ImageLoadingDataset(torch.utils.data.Dataset): - def __init__(self, image_paths): - self.images = image_paths + def __init__(self, image_paths): + self.images = image_paths - def __len__(self): - return len(self.images) + def __len__(self): + return len(self.images) - def __getitem__(self, idx): - img_path = self.images[idx] + def __getitem__(self, idx): + img_path = self.images[idx] - try: - image = Image.open(img_path).convert("RGB") - # convert to tensor temporarily so dataloader will accept it - tensor_pil = transforms.functional.pil_to_tensor(image) - except Exception as e: - print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") - return None + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor_pil = transforms.functional.pil_to_tensor(image) + except Exception as e: + print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None - return (tensor_pil, img_path) + return (tensor_pil, img_path) # endregion diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index 99d22d8..d8d8205 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -5,7 +5,9 @@ from .common_gui import get_folder_path import os -def caption_images(train_data_dir, caption_extension, batch_size, thresh, replace_underscores): +def caption_images( + train_data_dir, caption_extension, batch_size, thresh, replace_underscores +): # Check for caption_text_input # if caption_text_input == "": # msgbox("Caption text is missing...") @@ -76,7 +78,7 @@ def gradio_wd14_caption_gui_tab(): batch_size = gr.Number( value=1, label='Batch size', interactive=True ) - + replace_underscores = gr.Checkbox( label='Replace underscores in filenames with spaces', value=False, @@ -87,6 +89,12 @@ def gradio_wd14_caption_gui_tab(): caption_button.click( caption_images, - inputs=[train_data_dir, caption_extension, batch_size, thresh, replace_underscores], + inputs=[ + train_data_dir, + caption_extension, + batch_size, + thresh, + replace_underscores, + ], show_progress=False, ) diff --git a/lora_gui.py b/lora_gui.py index a043c4b..a50a661 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -4,6 +4,7 @@ # v3.1: Adding captionning of images to utilities import gradio as gr +import easygui import json import math import os @@ -26,6 +27,7 @@ from library.common_gui import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, + check_if_model_exist, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -120,7 +122,8 @@ def save_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -236,15 +239,16 @@ def open_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) - + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - + if ask_for_file: file_path = get_file_path(file_path) @@ -342,10 +346,11 @@ def train_model( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): print_only_bool = True if print_only.get('label') == 'True' else False - + if pretrained_model_name_or_path == '': msgbox('Source model information is missing') return @@ -380,6 +385,9 @@ def train_model( ) stop_text_encoder_training_pct = 0 + if check_if_model_exist(output_name, output_dir, save_model_as): + return + # If string is empty set string to 0. if text_encoder_lr == '': text_encoder_lr = 0 @@ -417,7 +425,7 @@ def train_model( or f.endswith('.webp') ] ) - + print(f'Folder {folder}: {num_images} images found') # Calculate the total number of steps for this folder @@ -425,7 +433,7 @@ def train_model( # Print the result print(f'Folder {folder}: {steps} steps') - + total_steps += steps # calculate max_train_steps @@ -492,9 +500,7 @@ def train_model( ) return run_cmd += f' --network_module=lycoris.kohya' - run_cmd += ( - f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"' - ) + run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=lora"' if LoRA_type == 'LyCORIS/LoHa': try: import lycoris @@ -504,9 +510,7 @@ def train_model( ) return run_cmd += f' --network_module=lycoris.kohya' - run_cmd += ( - f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"' - ) + run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"' if LoRA_type == 'Kohya LoCon': run_cmd += f' --network_module=networks.lora' run_cmd += ( @@ -595,8 +599,10 @@ def train_model( output_dir, ) - if print_only_bool: - print('\033[93m\nHere is the trainer command as a reference. It will not be executed:\033[0m\n') + if print_only_bool: + print( + '\033[93m\nHere is the trainer command as a reference. It will not be executed:\033[0m\n' + ) print('\033[96m' + run_cmd + '\033[0m\n') else: print(run_cmd) @@ -611,7 +617,9 @@ def train_model( if not last_dir.is_dir(): # Copy inference model for v2 if required - save_inference_file(output_dir, v2, v_parameterization, output_name) + save_inference_file( + output_dir, v2, v_parameterization, output_name + ) def lora_tab( @@ -811,7 +819,12 @@ def lora_tab( # Show of hide LoCon conv settings depending on LoRA type selection def LoRA_type_change(LoRA_type): print('LoRA type changed...') - if LoRA_type == 'LoCon' or LoRA_type == 'Kohya LoCon' or LoRA_type == 'LyCORIS/LoHa' or LoRA_type == 'LyCORIS/LoCon': + if ( + LoRA_type == 'LoCon' + or LoRA_type == 'Kohya LoCon' + or LoRA_type == 'LyCORIS/LoHa' + or LoRA_type == 'LyCORIS/LoCon' + ): return gr.Group.update(visible=True) else: return gr.Group.update(visible=False) @@ -876,7 +889,8 @@ def lora_tab( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, - noise_offset,additional_parameters, + noise_offset, + additional_parameters, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -908,7 +922,7 @@ def lora_tab( gradio_verify_lora_tab() button_run = gr.Button('Train model', variant='primary') - + button_print = gr.Button('Print training command') # Setup gradio tensorboard buttons @@ -992,7 +1006,8 @@ def lora_tab( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ] button_open_config.click( @@ -1001,7 +1016,7 @@ def lora_tab( outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, ) - + button_load_config.click( open_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, @@ -1028,7 +1043,7 @@ def lora_tab( inputs=[dummy_db_false] + settings_list, show_progress=False, ) - + button_print.click( train_model, inputs=[dummy_db_true] + settings_list, diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index ed3c33a..876591d 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -26,6 +26,7 @@ from library.common_gui import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, + check_if_model_exist, ) from library.tensorboard_gui import ( gradio_tensorboard, @@ -110,7 +111,8 @@ def save_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -222,15 +224,16 @@ def open_configuration( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): # Get list of function parameters and values parameters = list(locals().items()) - + ask_for_file = True if ask_for_file.get('label') == 'True' else False original_file_path = file_path - + if ask_for_file: file_path = get_file_path(file_path) @@ -316,7 +319,8 @@ def train_model( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -350,6 +354,9 @@ def train_model( if not os.path.exists(output_dir): os.makedirs(output_dir) + if check_if_model_exist(output_name, output_dir, save_model_as): + return + # Get a list of all subfolders in train_data_dir subfolders = [ f @@ -761,7 +768,8 @@ def ti_tab( bucket_reso_steps, caption_dropout_every_n_epochs, caption_dropout_rate, - noise_offset,additional_parameters, + noise_offset, + additional_parameters, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -866,7 +874,8 @@ def ti_tab( sample_every_n_steps, sample_every_n_epochs, sample_sampler, - sample_prompts,additional_parameters, + sample_prompts, + additional_parameters, ] button_open_config.click( @@ -875,7 +884,7 @@ def ti_tab( outputs=[config_file_name] + settings_list, show_progress=False, ) - + button_load_config.click( open_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, diff --git a/train_db.py b/train_db.py index a302117..81aeda1 100644 --- a/train_db.py +++ b/train_db.py @@ -7,6 +7,7 @@ import argparse import itertools import math import os +import toml from tqdm import tqdm import torch @@ -17,348 +18,392 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, + ConfigSanitizer, + BlueprintGenerator, ) def collate_fn(examples): - return examples[0] + return examples[0] def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, False) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) - cache_latents = args.cache_latents + cache_latents = args.cache_latents - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - user_config = { - "datasets": [{ - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) - }] - } + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - if args.no_token_padding: - train_dataset_group.disable_token_padding() + if args.no_token_padding: + train_dataset_group.disable_token_padding() - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return - if cache_latents: - assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - # acceleratorを準備する - print("prepare accelerator") + # acceleratorを準備する + print("prepare accelerator") - if args.gradient_accumulation_steps > 1: - print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong") - print( - f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です") + if args.gradient_accumulation_steps > 1: + print( + f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" + ) + print( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" + ) - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator, unwrap_model = train_util.prepare_accelerator(args) - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # 学習を準備する:モデルを適切な状態にする - train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 - unet.requires_grad_(True) # 念のため追加 - text_encoder.requires_grad_(train_text_encoder) - if not train_text_encoder: - print("Text Encoder is not trained.") - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - if train_text_encoder: - trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) - else: - trainable_params = unet.parameters() - - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - if args.stop_text_encoder_training is None: - args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end - - # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - if not train_text_encoder: - text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) - - if accelerator.is_main_process: - accelerator.init_trackers("dreambooth") - - loss_list = [] - loss_total = 0.0 - for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - - # 指定したステップ数までText Encoderを学習する:epoch最初の状態 - unet.train() - # train==True is required to enable gradient_checkpointing - if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: - text_encoder.train() - - for step, batch in enumerate(train_dataloader): - # 指定したステップ数でText Encoderの学習を止める - if global_step == args.stop_text_encoder_training: - print(f"stop text encoder training at step {global_step}") - if not args.gradient_checkpointing: - text_encoder.train(False) - text_encoder.requires_grad_(False) - - with accelerator.accumulate(unet): + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + # 学習を準備する:モデルを適切な状態にする + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 + unet.requires_grad_(True) # 念のため追加 + text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + print("Text Encoder is not trained.") - # Get the text embedding for conditioning - with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + if train_text_encoder: + trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + trainable_params = unet.parameters() - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + _, _, optimizer = train_util.get_optimizer(args, trainable_params) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * len(train_dataloader) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - if train_text_encoder: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) - else: - params_to_clip = unet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + # acceleratorがなんかよろしくやってくれるらしい + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value - logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] - accelerator.log(logs, step=global_step) + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - if global_step >= args.max_train_steps: - break + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, step=epoch+1) + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 - accelerator.wait_for_everyone() + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) - if args.save_every_n_epochs is not None: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth") - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + loss_list = [] + loss_total = 0.0 + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) - is_main_process = accelerator.is_main_process - if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 + unet.train() + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: + text_encoder.train() - accelerator.end_training() + for step, batch in enumerate(train_dataloader): + # 指定したステップ数でText Encoderの学習を止める + if global_step == args.stop_text_encoder_training: + print(f"stop text encoder training at step {global_step}") + if not args.gradient_checkpointing: + text_encoder.train(False) + text_encoder.requires_grad_(False) - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) + with accelerator.accumulate(unet): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] - del accelerator # この後メモリを使うのでこれは消す + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, global_step, text_encoder, unet, vae) - print("model saved.") + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + if train_text_encoder: + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + print("model saved.") -if __name__ == '__main__': - parser = argparse.ArgumentParser() +if __name__ == "__main__": + parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, False, True) - train_util.add_training_arguments(parser, True) - train_util.add_sd_saving_arguments(parser) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False, True) + train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) - parser.add_argument("--no_token_padding", action="store_true", - help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") - parser.add_argument("--stop_text_encoder_training", type=int, default=None, - help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない") + parser.add_argument( + "--no_token_padding", + action="store_true", + help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", + ) + parser.add_argument( + "--stop_text_encoder_training", + type=int, + default=None, + help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", + ) - args = parser.parse_args() - train(args) + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_network.py b/train_network.py index 5aa8af4..7f910df 100644 --- a/train_network.py +++ b/train_network.py @@ -7,6 +7,7 @@ import os import random import time import json +import toml from tqdm import tqdm import torch @@ -25,635 +26,668 @@ from library.config_util import ( def collate_fn(examples): - return examples[0] + return examples[0] # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): - logs = {"loss/current": current_loss, "loss/average": avr_loss} + logs = {"loss/current": current_loss, "loss/average": avr_loss} - if args.network_train_unet_only: - logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) - else: - logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) - logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder + if args.network_train_unet_only: + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0]) + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + else: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. - logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr'] + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - return logs + return logs def train(args): - session_id = random.randint(0, 2**32) - training_started_at = time.time() - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None - use_user_config = args.dataset_config is not None + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None - if args.seed is not None: - set_seed(args.seed) + if args.seed is not None: + set_seed(args.seed) - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) - if use_user_config: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - if use_dreambooth_method: - print("Use DreamBooth method.") - user_config = { - "datasets": [{ - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) - }] - } + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) else: - print("Train with captions.") - user_config = { - "datasets": [{ - "subsets": [{ - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - }] - }] - } - - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group) - return - if len(train_dataset_group) == 0: - print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)") - return - - if cache_latents: - assert train_dataset_group.is_latent_cacheable( - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) - - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - - # work on low-ram device - if args.lowram: - text_encoder.to("cuda") - unet.to("cuda") - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # prepare network - import sys - sys.path.append(os.path.dirname(__file__)) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - - net_kwargs = {} - if args.network_args is not None: - for net_arg in args.network_args: - key, value = net_arg.split('=') - net_kwargs[key] = value - - # if a new network is added in future, add if ~ then blocks for each network (;'∀') - network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) - if network is None: - return - - if args.network_weights is not None: - print("load network weights from:", args.network_weights) - network.load_weights(args.network_weights) - - train_unet = not args.network_train_text_encoder_only - train_text_encoder = not args.network_train_unet_only - network.apply_to(text_encoder, unet, train_text_encoder, train_unet) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - network.enable_gradient_checkpointing() # may have no effect - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) - optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) - if is_main_process: - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - network.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if train_unet and train_text_encoder: - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler) - elif train_unet: - unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, network, optimizer, train_dataloader, lr_scheduler) - elif train_text_encoder: - text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, network, optimizer, train_dataloader, lr_scheduler) - else: - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - network, optimizer, train_dataloader, lr_scheduler) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) - text_encoder.to(accelerator.device) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - text_encoder.train() - - # set top parameter requires_grad = True for gradient checkpointing works - if type(text_encoder) == DDP: - text_encoder.module.text_model.embeddings.requires_grad_(True) - else: - text_encoder.text_model.embeddings.requires_grad_(True) - else: - unet.eval() - text_encoder.eval() - - # support DistributedDataParallel - if type(text_encoder) == DDP: - text_encoder = text_encoder.module - unet = unet.module - network = network.module - - network.prepare_grad_etc(text_encoder, unet) - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - # TODO: find a way to handle total batch size when there are multiple datasets - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - if is_main_process: - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") - # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - # TODO refactor metadata creation and move to util - metadata = { - "ss_session_id": session_id, # random integer indicating which group of epochs the model came from - "ss_training_started_at": training_started_at, # unix timestamp - "ss_output_name": args.output_name, - "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, - "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset_group.num_train_images, - "ss_num_reg_images": train_dataset_group.num_reg_images, - "ss_num_batches_per_epoch": len(train_dataloader), - "ss_num_epochs": num_train_epochs, - "ss_gradient_checkpointing": args.gradient_checkpointing, - "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, - "ss_max_train_steps": args.max_train_steps, - "ss_lr_warmup_steps": args.lr_warmup_steps, - "ss_lr_scheduler": args.lr_scheduler, - "ss_network_module": args.network_module, - "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim - "ss_network_alpha": args.network_alpha, # some networks may not use this value - "ss_mixed_precision": args.mixed_precision, - "ss_full_fp16": bool(args.full_fp16), - "ss_v2": bool(args.v2), - "ss_clip_skip": args.clip_skip, - "ss_max_token_length": args.max_token_length, - "ss_cache_latents": bool(args.cache_latents), - "ss_seed": args.seed, - "ss_lowram": args.lowram, - "ss_noise_offset": args.noise_offset, - "ss_training_comment": args.training_comment, # will not be updated after training - "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), - "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), - "ss_max_grad_norm": args.max_grad_norm, - "ss_caption_dropout_rate": args.caption_dropout_rate, - "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, - "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, - "ss_face_crop_aug_range": args.face_crop_aug_range, - "ss_prior_loss_weight": args.prior_loss_weight, - } - - if use_user_config: - # save metadata of multiple datasets - # NOTE: pack "ss_datasets" value as json one time - # or should also pack nested collections as json? - datasets_metadata = [] - tag_frequency = {} # merge tag frequency for metadata editor - dataset_dirs_info = {} # merge subset dirs for metadata editor - - for dataset in train_dataset_group.datasets: - is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) - dataset_metadata = { - "is_dreambooth": is_dreambooth_dataset, - "batch_size_per_device": dataset.batch_size, - "num_train_images": dataset.num_train_images, # includes repeating - "num_reg_images": dataset.num_reg_images, - "resolution": (dataset.width, dataset.height), - "enable_bucket": bool(dataset.enable_bucket), - "min_bucket_reso": dataset.min_bucket_reso, - "max_bucket_reso": dataset.max_bucket_reso, - "tag_frequency": dataset.tag_frequency, - "bucket_info": dataset.bucket_info, - } - - subsets_metadata = [] - for subset in dataset.subsets: - subset_metadata = { - "img_count": subset.img_count, - "num_repeats": subset.num_repeats, - "color_aug": bool(subset.color_aug), - "flip_aug": bool(subset.flip_aug), - "random_crop": bool(subset.random_crop), - "shuffle_caption": bool(subset.shuffle_caption), - "keep_tokens": subset.keep_tokens, - } - - image_dir_or_metadata_file = None - if subset.image_dir: - image_dir = os.path.basename(subset.image_dir) - subset_metadata["image_dir"] = image_dir - image_dir_or_metadata_file = image_dir - - if is_dreambooth_dataset: - subset_metadata["class_tokens"] = subset.class_tokens - subset_metadata["is_reg"] = subset.is_reg - if subset.is_reg: - image_dir_or_metadata_file = None # not merging reg dataset + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } else: - metadata_file = os.path.basename(subset.metadata_file) - subset_metadata["metadata_file"] = metadata_file - image_dir_or_metadata_file = metadata_file # may overwrite + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - subsets_metadata.append(subset_metadata) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - # merge dataset dir: not reg subset only - # TODO update additional-network extension to show detailed dataset config from metadata - if image_dir_or_metadata_file is not None: - # datasets may have a certain dir multiple times - v = image_dir_or_metadata_file - i = 2 - while v in dataset_dirs_info: - v = image_dir_or_metadata_file + f" ({i})" - i += 1 - image_dir_or_metadata_file = v + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return - dataset_dirs_info[image_dir_or_metadata_file] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count - } + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - dataset_metadata["subsets"] = subsets_metadata - datasets_metadata.append(dataset_metadata) + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process - # merge tag frequency: - for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): - # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える - # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない - # なので、ここで複数datasetの回数を合算してもあまり意味はない - if ds_dir_name in tag_frequency: - continue - tag_frequency[ds_dir_name] = ds_freq_for_dir + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - metadata["ss_datasets"] = json.dumps(datasets_metadata) - metadata["ss_tag_frequency"] = json.dumps(tag_frequency) - metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) - else: - # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir - assert len( - train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - dataset = train_dataset_group.datasets[0] + # work on low-ram device + if args.lowram: + text_encoder.to("cuda") + unet.to("cuda") - dataset_dirs_info = {} - reg_dataset_dirs_info = {} - if use_dreambooth_method: - for subset in dataset.subsets: - info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info - info[os.path.basename(subset.image_dir)] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count - } - else: - for subset in dataset.subsets: - dataset_dirs_info[os.path.basename(subset.metadata_file)] = { - "n_repeats": subset.num_repeats, - "img_count": subset.img_count - } + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - metadata.update({ - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, - "ss_resolution": args.resolution, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), - "ss_enable_bucket": bool(dataset.enable_bucket), - "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), - "ss_min_bucket_reso": dataset.min_bucket_reso, - "ss_max_bucket_reso": dataset.max_bucket_reso, - "ss_keep_tokens": args.keep_tokens, - "ss_dataset_dirs": json.dumps(dataset_dirs_info), - "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), - "ss_tag_frequency": json.dumps(dataset.tag_frequency), - "ss_bucket_info": json.dumps(dataset.bucket_info), - }) - - # add extra args - if args.network_args: - metadata["ss_network_args"] = json.dumps(net_kwargs) - # for key, value in net_kwargs.items(): - # metadata["ss_arg_" + key] = value - - # model name and hash - if args.pretrained_model_name_or_path is not None: - sd_model_name = args.pretrained_model_name_or_path - if os.path.exists(sd_model_name): - metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) - metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) - sd_model_name = os.path.basename(sd_model_name) - metadata["ss_sd_model_name"] = sd_model_name - - if args.vae is not None: - vae_name = args.vae - if os.path.exists(vae_name): - metadata["ss_vae_hash"] = train_util.model_hash(vae_name) - metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) - vae_name = os.path.basename(vae_name) - metadata["ss_vae_name"] = vae_name - - metadata = {k: str(v) for k, v in metadata.items()} - - # make minimum metadata for filtering - minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] - minimum_metadata = {} - for key in minimum_keys: - if key in metadata: - minimum_metadata[key] = metadata[key] - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) - - if accelerator.is_main_process: - accelerator.init_trackers("network_train") - - loss_list = [] - loss_total = 0.0 - for epoch in range(num_train_epochs): - if is_main_process: - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - - metadata["ss_epoch"] = str(epoch+1) - - network.on_epoch_start(text_encoder, unet) - - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(network): + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - with torch.set_grad_enabled(train_text_encoder): - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + # prepare network + import sys - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + if network is None: + return - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + if args.network_weights is not None: + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes) + if is_main_process: + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_unet and train_text_encoder: + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_text_encoder: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + text_encoder.train() + + # set top parameter requires_grad = True for gradient checkpointing works + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) else: - target = noise + text_encoder.text_model.embeddings.requires_grad_(True) + else: + unet.eval() + text_encoder.eval() - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + # support DistributedDataParallel + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + network.prepare_grad_etc(text_encoder, unet) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = network.get_trainable_params() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - current_loss = loss.detach().item() - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + if is_main_process: + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not use this value + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + } - if global_step >= args.max_train_steps: - break + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, step=epoch+1) + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } - accelerator.wait_for_everyone() + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } - if args.save_every_n_epochs is not None: - model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir - def save_func(): - ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える + # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない + # なので、ここで複数datasetの回数を合算してもあまり意味はない + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) + + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + # for key, value in net_kwargs.items(): + # metadata["ss_arg_" + key] = value + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] + minimum_metadata = {} + for key in minimum_keys: + if key in metadata: + minimum_metadata[key] = metadata[key] + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("network_train") + + loss_list = [] + loss_total = 0.0 + for epoch in range(num_train_epochs): + if is_main_process: + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) + + metadata["ss_epoch"] = str(epoch + 1) + + network.on_epoch_start(text_encoder, unet) + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(network): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + metadata["ss_training_finished_at"] = str(time.time()) + print(f"saving checkpoint: {ckpt_file}") + unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + if is_main_process: + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + metadata["ss_epoch"] = str(num_train_epochs) + metadata["ss_training_finished_at"] = str(time.time()) + + if is_main_process: + network = unwrap_model(network) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) - metadata["ss_training_finished_at"] = str(time.time()) - print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - def remove_old_func(old_epoch_no): - old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - if is_main_process: - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) - - # end of epoch - - metadata["ss_epoch"] = str(num_train_epochs) - metadata["ss_training_finished_at"] = str(time.time()) - - if is_main_process: - network = unwrap_model(network) - - accelerator.end_training() - - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - - model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - ckpt_name = model_name + '.' + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) - print("model saved.") + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + print("model saved.") -if __name__ == '__main__': - parser = argparse.ArgumentParser() +if __name__ == "__main__": + parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, True) - train_util.add_training_arguments(parser, True) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) - parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") - parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)") + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) - parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--network_weights", type=str, default=None, - help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') - parser.add_argument("--network_dim", type=int, default=None, - help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') - parser.add_argument("--network_alpha", type=float, default=1, - help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)') - parser.add_argument("--network_args", type=str, default=None, nargs='*', - help='additional argmuments for network (key=value) / ネットワークへの追加の引数') - parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") - parser.add_argument("--network_train_text_encoder_only", action="store_true", - help="only training Text Encoder part / Text Encoder関連部分のみ学習する") - parser.add_argument("--training_comment", type=str, default=None, - help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") + parser.add_argument( + "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + ) + parser.add_argument( + "--network_alpha", + type=float, + default=1, + help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)", + ) + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") + parser.add_argument( + "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + ) + parser.add_argument( + "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + ) - args = parser.parse_args() - train(args) + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 34b7f09..e4ab7b5 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -3,6 +3,7 @@ import argparse import gc import math import os +import toml from tqdm import tqdm import torch @@ -13,8 +14,8 @@ from diffusers import DDPMScheduler import library.train_util as train_util import library.config_util as config_util from library.config_util import ( - ConfigSanitizer, - BlueprintGenerator, + ConfigSanitizer, + BlueprintGenerator, ) imagenet_templates_small = [ @@ -71,456 +72,500 @@ imagenet_style_templates_small = [ def collate_fn(examples): - return examples[0] + return examples[0] def train(args): - if args.output_name is None: - args.output_name = args.token_string - use_template = args.use_object_template or args.use_style_template + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - cache_latents = args.cache_latents + cache_latents = args.cache_latents - if args.seed is not None: - set_seed(args.seed) + if args.seed is not None: + set_seed(args.seed) - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - # Convert the init_word to token_id - if args.init_word is not None: - init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) - if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: - print( - f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}") - else: - init_token_ids = None - - # add new word to tokenizer, count is num_vectors_per_token - token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] - num_added_tokens = tokenizer.add_tokens(token_strings) - assert num_added_tokens == args.num_vectors_per_token, f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) - - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data - if init_token_ids is not None: - for i, token_id in enumerate(token_ids): - token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - - # load weights - if args.weights is not None: - embeddings = load_weights(args.weights) - assert len(token_ids) == len( - embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" - # print(token_ids, embeddings.size()) - for token_id, embedding in zip(token_ids, embeddings): - token_embeds[token_id] = embedding - # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) - print(f"weighs loaded") - - print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") - - # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored))) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - print("Use DreamBooth method.") - user_config = { - "datasets": [{ - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir) - }] - } + # Convert the init_word to token_id + if args.init_word is not None: + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" + ) else: - print("Train with captions.") - user_config = { - "datasets": [{ - "subsets": [{ - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - }] - }] - } + init_token_ids = None - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + # add new word to tokenizer, count is num_vectors_per_token + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" - # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 - if use_template: - print("use template for training captions. is object: {args.use_object_template}") - templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small - replace_to = " ".join(token_strings) - captions = [] - for tmpl in templates: - captions.append(tmpl.format(replace_to)) - train_dataset_group.add_replacement("", captions) + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - if args.num_vectors_per_token > 1: - prompt_replacement = (args.token_string, replace_to) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids): + token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + + # load weights + if args.weights is not None: + embeddings = load_weights(args.weights) + assert len(token_ids) == len( + embeddings + ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" + # print(token_ids, embeddings.size()) + for token_id, embedding in zip(token_ids, embeddings): + token_embeds[token_id] = embedding + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + print(f"weighs loaded") + + print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) else: - prompt_replacement = None - else: - if args.num_vectors_per_token > 1: - replace_to = " ".join(token_strings) - train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = (args.token_string, replace_to) - else: - prompt_replacement = None - - if args.debug_dataset: - train_util.debug_dataset(train_dataset_group, show_input_ids=True) - return - if len(train_dataset_group) == 0: - print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") - return - - if cache_latents: - assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset_group.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") - trainable_params = text_encoder.get_input_embeddings().parameters() - _, _, optimizer = train_util.get_optimizer(args, trainable_params) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで - train_dataloader = torch.utils.data.DataLoader( - train_dataset_group, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers) - - # 学習ステップ数を計算する - if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) - print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - - # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) - - # acceleratorがなんかよろしくやってくれるらしい - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler) - - index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] - # print(len(index_no_updates), torch.sum(index_no_updates)) - orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - - # Freeze all parameters except for the token embeddings in text encoder - text_encoder.requires_grad_(True) - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - - unet.requires_grad_(False) - unet.to(accelerator.device, dtype=weight_dtype) - if args.gradient_checkpointing: # according to TI example in Diffusers, train is required - unet.train() - else: - unet.eval() - - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - text_encoder.to(weight_dtype) - - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") - print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) - - if accelerator.is_main_process: - accelerator.init_trackers("textual_inversion") - - for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - - text_encoder.train() - - loss_total = 0 - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - # weight_dtype) use float instead of fp16/bf16 because text encoder is float - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } else: - target = noise + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 + if use_template: + print("use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = text_encoder.get_input_embeddings().parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - # Let's make sure we don't update any embedding weights besides the newly added token + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates] + train_dataset_group.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() - train_util.sample_images(accelerator, args, None, global_step, accelerator.device, - vae, tokenizer, text_encoder, unet, prompt_replacement) + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + trainable_params = text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) - current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value - logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr'] - accelerator.log(logs, step=global_step) + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collate_fn, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) - loss_total += current_loss - avr_loss = loss_total / (step+1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * len(train_dataloader) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") - if global_step >= args.max_train_steps: - break + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch+1) + # acceleratorがなんかよろしくやってくれるらしい + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) - accelerator.wait_for_everyone() + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] + # print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() - updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) - if args.save_every_n_epochs is not None: - model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + else: + unet.eval() - def save_func(): - ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + text_encoder.to(weight_dtype) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + train_dataset_group.set_current_epoch(epoch + 1) + + text_encoder.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + # weight_dtype) use float instead of fp16/bf16 because text encoder is float + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + index_no_updates + ] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + # end of epoch + + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) - print(f"saving checkpoint: {ckpt_file}") + + print(f"save trained model to {ckpt_file}") save_weights(ckpt_file, updated_embs, save_dtype) - - def remove_old_func(old_epoch_no): - old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) - if saving and args.save_state: - train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) - - train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, - vae, tokenizer, text_encoder, unet, prompt_replacement) - - # end of epoch - - is_main_process = accelerator.is_main_process - if is_main_process: - text_encoder = unwrap_model(text_encoder) - - accelerator.end_training() - - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - - model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - ckpt_name = model_name + '.' + args.save_model_as - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - print(f"save trained model to {ckpt_file}") - save_weights(ckpt_file, updated_embs, save_dtype) - print("model saved.") + print("model saved.") def save_weights(file, updated_embs, save_dtype): - state_dict = {"emb_params": updated_embs} + state_dict = {"emb_params": updated_embs} - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import save_file - save_file(state_dict, file) - else: - torch.save(state_dict, file) # can be loaded in Web UI + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI def load_weights(file): - if os.path.splitext(file)[1] == '.safetensors': - from safetensors.torch import load_file - data = load_file(file) - else: - # compatible to Web UI's file format - data = torch.load(file, map_location='cpu') - if type(data) != dict: - raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file - if 'string_to_param' in data: # textual inversion embeddings - data = data['string_to_param'] - if hasattr(data, '_parameters'): # support old PyTorch? - data = getattr(data, '_parameters') + data = load_file(file) + else: + # compatible to Web UI's file format + data = torch.load(file, map_location="cpu") + if type(data) != dict: + raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") - emb = next(iter(data.values())) - if type(emb) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") + if "string_to_param" in data: # textual inversion embeddings + data = data["string_to_param"] + if hasattr(data, "_parameters"): # support old PyTorch? + data = getattr(data, "_parameters") - if len(emb.size()) == 1: - emb = emb.unsqueeze(0) + emb = next(iter(data.values())) + if type(emb) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") - return emb + if len(emb.size()) == 1: + emb = emb.unsqueeze(0) + + return emb -if __name__ == '__main__': - parser = argparse.ArgumentParser() +if __name__ == "__main__": + parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, True, True, False) - train_util.add_training_arguments(parser, True) - train_util.add_optimizer_arguments(parser) - config_util.add_config_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) - parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], - help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument( + "--save_model_as", + type=str, + default="pt", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", + ) - parser.add_argument("--weights", type=str, default=None, - help="embedding weights to initialize / 学習するネットワークの初期重み") - parser.add_argument("--num_vectors_per_token", type=int, default=1, - help='number of vectors per token / トークンに割り当てるembeddingsの要素数') - parser.add_argument("--token_string", type=str, default=None, - help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること") - parser.add_argument("--init_word", type=str, default=None, - help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") - parser.add_argument("--use_object_template", action='store_true', - help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する") - parser.add_argument("--use_style_template", action='store_true', - help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する") + parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" + ) + parser.add_argument( + "--token_string", + type=str, + default=None, + help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", + ) + parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--use_object_template", + action="store_true", + help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する", + ) + parser.add_argument( + "--use_style_template", + action="store_true", + help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", + ) - args = parser.parse_args() - train(args) + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args)