From 8822eab5a69dcd4b0c14874cfad2b7ee73dfe391 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 28 Mar 2023 11:54:42 -0400 Subject: [PATCH] Merge new sd-scripts updates --- README.md | 10 +- dreambooth_gui.py | 38 +- fine_tune.py | 34 +- finetune_gui.py | 15 +- gen_img_diffusers.py | 4912 +++++++++++++++-------------- library/common_gui.py | 8 +- library/config_util.py | 11 +- library/custom_train_functions.py | 28 +- library/model_util.py | 12 +- library/train_util.py | 179 +- lora_gui.py | 39 +- networks/resize_lora.py | 33 +- textual_inversion_gui.py | 30 +- train_db - Copy.py | 426 +++ train_db.py | 31 +- train_network - Copy.py | 710 +++++ train_network.py | 31 +- train_textual_inversion - Copy.py | 589 ++++ train_textual_inversion.py | 25 +- 19 files changed, 4705 insertions(+), 2456 deletions(-) create mode 100644 train_db - Copy.py create mode 100644 train_network - Copy.py create mode 100644 train_textual_inversion - Copy.py diff --git a/README.md b/README.md index 7fb6433..8b5b7a7 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,15 @@ This will store your a backup file with your current locally installed pip packa ## Change History -* 2023/03/26 (v21.3.6) +* 2023/03/28 (v21.3.6) + - Fix issues when `--persistent_data_loader_workers` is specified. + - The batch members of the bucket are not shuffled. + - `--caption_dropout_every_n_epochs` does not work. + - These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue. + - Fix an issue that images are loaded twice in Windows environment. + - Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work! + - Add `--min_snr_gamma` option to training scripts, 5 is recommended by paper. + - The Min SNR gamma fiels can be found unser the advanced training tab in all trainers. - Fixed the error while images are ended with capital image extensions. Thanks to @kvzn. https://github.com/bmaltais/kohya_ss/pull/454 * 2023/03/26 (v21.3.5) - Fix for https://github.com/bmaltais/kohya_ss/issues/230 diff --git a/dreambooth_gui.py b/dreambooth_gui.py index fc85f1b..e93f96e 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -108,6 +108,7 @@ def save_configuration( sample_prompts, additional_parameters, vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -216,6 +217,7 @@ def open_configuration( sample_prompts, additional_parameters, vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -306,6 +308,7 @@ def train_model( sample_prompts, additional_parameters, vae_batch_size, + min_snr_gamma, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -335,12 +338,17 @@ def train_model( subfolders = [ f for f in os.listdir(train_data_dir) - if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith('.') + if os.path.isdir(os.path.join(train_data_dir, f)) + and not f.startswith('.') ] # Check if subfolders are present. If not let the user know and return if not subfolders: - print('\033[33mNo subfolders were found in', train_data_dir, ' can\'t train\...033[0m') + print( + '\033[33mNo subfolders were found in', + train_data_dir, + " can't train\...033[0m", + ) return total_steps = 0 @@ -351,7 +359,11 @@ def train_model( try: repeats = int(folder.split('_')[0]) except ValueError: - print('\033[33mSubfolder', folder, 'does not have a proper repeat value, please correct the name or remove it... can\'t train...\033[0m') + print( + '\033[33mSubfolder', + folder, + "does not have a proper repeat value, please correct the name or remove it... can't train...\033[0m", + ) continue # Count the number of images in the folder @@ -359,12 +371,15 @@ def train_model( [ f for f, lower_f in ( - (file, file.lower()) for file in os.listdir(os.path.join(train_data_dir, folder)) + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) ) if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) ] ) - + if num_images == 0: print(f'{folder} folder contain no images, skipping...') else: @@ -376,7 +391,11 @@ def train_model( print('\033[33mFolder', folder, ':', steps, 'steps\033[0m') if total_steps == 0: - print('\033[33mNo images were found in folder', train_data_dir, '... please rectify!\033[0m') + print( + '\033[33mNo images were found in folder', + train_data_dir, + '... please rectify!\033[0m', + ) return # Print the result @@ -385,7 +404,9 @@ def train_model( if reg_data_dir == '': reg_factor = 1 else: - print('\033[94mRegularisation images are used... Will double the number of steps required...\033[0m') + print( + '\033[94mRegularisation images are used... Will double the number of steps required...\033[0m' + ) reg_factor = 2 # calculate max_train_steps @@ -497,6 +518,7 @@ def train_model( noise_offset=noise_offset, additional_parameters=additional_parameters, vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, ) run_cmd += run_cmd_sample( @@ -704,6 +726,7 @@ def dreambooth_tab( noise_offset, additional_parameters, vae_batch_size, + min_snr_gamma, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -805,6 +828,7 @@ def dreambooth_tab( sample_prompts, additional_parameters, vae_batch_size, + min_snr_gamma, ] button_open_config.click( diff --git a/fine_tune.py b/fine_tune.py index 1acf478..637a729 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -6,6 +6,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -19,10 +20,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - - -def collate_fn(examples): - return examples[0] +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def train(args): @@ -64,6 +63,11 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -187,16 +191,21 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -255,13 +264,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + current_epoch.value = epoch + 1 for m in training_models: m.train() loss_total = 0 for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -302,7 +312,14 @@ def train(args): else: target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + if args.min_snr_gamma: + # do not mean over batch dimension for snr weight + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = loss.mean() # mean over batch dimension + else: + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -396,6 +413,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") diff --git a/finetune_gui.py b/finetune_gui.py index 2f12f7d..b085928 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -104,7 +104,9 @@ def save_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -217,7 +219,9 @@ def open_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -312,7 +316,9 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): if check_if_model_exist(output_name, output_dir, save_model_as): return @@ -473,6 +479,7 @@ def train_model( noise_offset=noise_offset, additional_parameters=additional_parameters, vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, ) run_cmd += run_cmd_sample( @@ -690,6 +697,7 @@ def finetune_tab(): noise_offset, additional_parameters, vae_batch_size, + min_snr_gamma, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -785,6 +793,7 @@ def finetune_tab(): sample_prompts, additional_parameters, vae_batch_size, + min_snr_gamma, ] button_run.click(train_model, inputs=settings_list) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 38bc86e..690d111 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,11 +65,22 @@ import diffusers import numpy as np import torch import torchvision -from diffusers import (AutoencoderKL, DDPMScheduler, - EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, - KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, - UNet2DConditionModel, StableDiffusionPipeline) +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + UNet2DConditionModel, + StableDiffusionPipeline, +) from einops import rearrange from torch import einsum from tqdm import tqdm @@ -86,7 +97,7 @@ from tools.original_control_net import ControlNetInfo # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う DEFAULT_TOKEN_LENGTH = 75 @@ -94,7 +105,7 @@ DEFAULT_TOKEN_LENGTH = 75 SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = 'scaled_linear' +SCHEDLER_SCHEDULE = "scaled_linear" # その他の設定 LATENT_CHANNELS = 4 @@ -133,11 +144,12 @@ EPSILON = 1e-6 def exists(val): - return val is not None + return val is not None def default(val, d): - return val if exists(val) else d + return val if exists(val) else d + # flash attention forwards and backwards @@ -145,243 +157,247 @@ def default(val, d): class FlashAttentionFunction(torch.autograd.Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - scale = (q.shape[-1] ** -0.5) + scale = q.shape[-1] ** -0.5 - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + exp_values = einsum("... i j, ... j d -> ... i d", exp_weights, vc) - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - return o + return o - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors - device = q.device + device = q.device - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + attn_weights = einsum("... i d, ... j d -> ... i j", qc, kc) * scale - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) - exp_attn_weights = torch.exp(attn_weights - mc) + exp_attn_weights = torch.exp(attn_weights - mc) - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) - p = exp_attn_weights / lc + p = exp_attn_weights / lc - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) + dv_chunk = einsum("... i j, ... i d -> ... j d", p, doc) + dp = einsum("... i d, ... j d -> ... i j", doc, vc) - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + dq_chunk = einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = einsum("... i j, ... i d -> ... j d", ds, qc) - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") - flash_func = FlashAttentionFunction + print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") + flash_func = FlashAttentionFunction - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 - h = self.heads - q = self.to_q(x) + h = self.heads + q = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) + context = context if context is not None else x + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + diffusers.models.attention.CrossAttention.forward = forward_flash_attn def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") + print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + context = default(context, x) + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, "hypernetwork") and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, 'b n h d -> b n (h d)', h=h) + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - diffusers.models.attention.CrossAttention.forward = forward_xformers # endregion # region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 @@ -389,1071 +405,1168 @@ def replace_unet_cross_attn_to_xformers(): # Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 -class PipelineLike(): - r""" - Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing - weighting in prompt. - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ +class PipelineLike: + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ - def __init__( - self, - device, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - clip_model: CLIPModel, - clip_guidance_scale: float, - clip_image_guidance_scale: float, - vgg16_model: torchvision.models.VGG, - vgg16_guidance_scale: float, - vgg16_layer_no: int, - # safety_checker: StableDiffusionSafetyChecker, - # feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - self.device = device - self.clip_skip = clip_skip + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + clip_model: CLIPModel, + clip_guidance_scale: float, + clip_image_guidance_scale: float, + vgg16_model: torchvision.models.VGG, + vgg16_guidance_scale: float, + vgg16_layer_no: int, + # safety_checker: StableDiffusionSafetyChecker, + # feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) - self.vae = vae - self.text_encoder = text_encoder - self.tokenizer = tokenizer - self.unet = unet - self.scheduler = scheduler - self.safety_checker = None + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.unet = unet + self.scheduler = scheduler + self.safety_checker = None + + # Textual Inversion + self.token_replacements = {} + + # CLIP guidance + self.clip_guidance_scale = clip_guidance_scale + self.clip_image_guidance_scale = clip_image_guidance_scale + self.clip_model = clip_model + self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD) + self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE) + + # VGG16 guidance + self.vgg16_guidance_scale = vgg16_guidance_scale + if self.vgg16_guidance_scale > 0.0: + return_layers = {f"{vgg16_layer_no}": "feat"} + self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter( + vgg16_model.features, return_layers=return_layers + ) + self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) + + # ControlNet + self.control_nets: List[ControlNetInfo] = [] # Textual Inversion - self.token_replacements = {} + def add_token_replacement(self, target_token_id, rep_token_ids): + self.token_replacements[target_token_id] = rep_token_ids - # CLIP guidance - self.clip_guidance_scale = clip_guidance_scale - self.clip_image_guidance_scale = clip_image_guidance_scale - self.clip_model = clip_model - self.normalize = transforms.Normalize(mean=FEATURE_EXTRACTOR_IMAGE_MEAN, std=FEATURE_EXTRACTOR_IMAGE_STD) - self.make_cutouts = MakeCutouts(FEATURE_EXTRACTOR_SIZE) + def replace_token(self, tokens): + new_tokens = [] + for token in tokens: + if token in self.token_replacements: + new_tokens.extend(self.token_replacements[token]) + else: + new_tokens.append(token) + return new_tokens - # VGG16 guidance - self.vgg16_guidance_scale = vgg16_guidance_scale - if self.vgg16_guidance_scale > 0.0: - return_layers = {f'{vgg16_layer_no}': 'feat'} - self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) - self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets - # ControlNet - self.control_nets: List[ControlNetInfo] = [] + # region xformersとか使う部分:独自に書き換えるので関係なし - # Textual Inversion - def add_token_replacement(self, target_token_id, rep_token_ids): - self.token_replacements[target_token_id] = rep_token_ids + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) - def replace_token(self, tokens): - new_tokens = [] - for token in tokens: - if token in self.token_replacements: - new_tokens.extend(self.token_replacements[token]) - else: - new_tokens.append(token) - return new_tokens + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) - # region xformersとか使う部分:独自に書き換えるので関係なし + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) - def enable_xformers_memory_efficient_attention(self): - r""" - Enable memory efficient attention as implemented in xformers. - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - """ - self.unet.set_use_memory_efficient_attention_xformers(True) + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + # accelerateが必要になるのでとりあえず省略 + raise NotImplementedError("cpu_offload is omitted.") + # if is_accelerate_available(): + # from accelerate import cpu_offload + # else: + # raise ImportError("Please install accelerate via `pip install accelerate`") - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. - """ - self.unet.set_use_memory_efficient_attention_xformers(False) + # device = self.device - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) + # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + # if cpu_offloaded_model is not None: + # cpu_offload(cpu_offloaded_model, device) - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) + # endregion - def enable_sequential_cpu_offload(self): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - # accelerateが必要になるのでとりあえず省略 - raise NotImplementedError("cpu_offload is omitted.") - # if is_accelerate_available(): - # from accelerate import cpu_offload - # else: - # raise ImportError("Please install accelerate via `pip install accelerate`") - - # device = self.device - - # for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: - # if cpu_offloaded_model is not None: - # cpu_offload(cpu_offloaded_model, device) -# endregion - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_prompts=None, - clip_guide_images=None, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - is_cancelled_callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. If the function returns - `True`, the inference will be cancelled. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - `None` if cancelled by `is_cancelled_callback`, - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - num_images_per_prompt = 1 # fixed - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - vae_batch_size = batch_size if vae_batch_size is None else ( - int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_prompts=None, + clip_guide_images=None, + **kwargs, ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + num_images_per_prompt = 1 # fixed - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""]*batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) - - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - # CLIP guidanceで使用するembeddingsを取得する - if self.clip_guidance_scale > 0: - clip_text_input = prompt_tokens - if clip_text_input.shape[1] > self.tokenizer.model_max_length: - # TODO 75文字を超えたら警告を出す? - print("trim text input", clip_text_input.shape) - clip_text_input = torch.cat([clip_text_input[:, :self.tokenizer.model_max_length-1], - clip_text_input[:, -1].unsqueeze(1)], dim=1) - print("trimmed", clip_text_input.shape) - - for i, clip_prompt in enumerate(clip_prompts): - if clip_prompt is not None: # clip_promptがあれば上書きする - clip_text_input[i] = self.tokenizer(clip_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, - truncation=True, return_tensors="pt",).input_ids.to(self.device) - - text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) - text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK - - if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets: - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - - if self.clip_image_guidance_scale > 0: - clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images, dim=0) - - clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) - image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images) - image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - if len(image_embeddings_clip) == 1: - image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) - elif self.vgg16_guidance_scale > 0: - size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?) - clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images, dim=0) - - clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) - image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat'] - if len(image_embeddings_vgg16) == 1: - image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) - else: - # ControlNetのhintにguide imageを流用する - # 前処理はControlNet側で行う - pass - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8,) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype,).to(self.device) + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) else: - latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype,) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - timesteps = self.scheduler.timesteps.to(self.device) + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - init_latents = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - init_latent_dist = self.vae.encode(init_image[i:i + vae_batch_size] - if vae_batch_size > 1 else init_image[i].unsqueeze(0)).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) - init_latents = 0.18215 * init_latents + # get prompt text embeddings - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) + text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + **kwargs, + ) - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) - - for i, t in enumerate(tqdm(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - if self.control_nets: - noise_pred = original_control_net.call_unet_and_control_net( - i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample - else: - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond - noise_pred = noise_pred_uncond + guidance_scale * \ - (noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond) - - # perform clip guidance - if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: - text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[ - 1] if do_classifier_free_guidance else text_embeddings) + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: - noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred, - text_embeddings_clip, self.clip_guidance_scale, NUM_CUTOUTS, USE_CUTOUTS,) - if self.clip_image_guidance_scale > 0 and clip_guide_images is not None: - noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred, - image_embeddings_clip, self.clip_image_guidance_scale, NUM_CUTOUTS, USE_CUTOUTS,) - if self.vgg16_guidance_scale > 0 and clip_guide_images is not None: - noise_pred, latents = self.cond_fn_vgg16(latents, t, i, text_embeddings_for_guidance, noise_pred, - image_embeddings_vgg16, self.vgg16_guidance_scale) + clip_text_input = prompt_tokens + if clip_text_input.shape[1] > self.tokenizer.model_max_length: + # TODO 75文字を超えたら警告を出す? + print("trim text input", clip_text_input.shape) + clip_text_input = torch.cat( + [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1 + ) + print("trimmed", clip_text_input.shape) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + for i, clip_prompt in enumerate(clip_prompts): + if clip_prompt is not None: # clip_promptがあれば上書きする + clip_text_input[i] = self.tokenizer( + clip_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.to(self.device) - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) + text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) + text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None + if ( + self.clip_image_guidance_scale > 0 + or self.vgg16_guidance_scale > 0 + and clip_guide_images is not None + or self.control_nets + ): + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] - if return_latents: - return (latents, False) + if self.clip_image_guidance_scale > 0: + clip_guide_images = [preprocess_guide_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) - latents = 1 / 0.18215 * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents).sample - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append(self.vae.decode(latents[i:i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample) - image = torch.cat(images) + clip_guide_images = self.normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_clip = self.clip_model.get_image_features(clip_guide_images) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + if len(image_embeddings_clip) == 1: + image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1)) + elif self.vgg16_guidance_scale > 0: + size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に(小さいか?) + clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images, dim=0) - image = (image / 2 + 0.5).clamp(0, 1) + clip_guide_images = self.vgg16_normalize(clip_guide_images).to(self.device).to(text_embeddings.dtype) + image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)["feat"] + if len(image_embeddings_vgg16) == 1: + image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1)) + else: + # ControlNetのhintにguide imageを流用する + # 前処理はControlNet側で行う + pass - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, - clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), - ) - else: - has_nsfw_concept = None + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] + if init_image is None: + # get the initial random noise unless the user supplied it - # if not return_dict: - return (image, has_nsfw_concept) + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) - def text2img( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for text-to-image generation. - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + init_latent_dist = self.vae.encode( + init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = 0.18215 * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.control_nets: + noise_pred = original_control_net.call_unet_and_control_net( + i, + num_latent_input, + self.unet, + self.control_nets, + guided_hints, + i / len(timesteps), + latent_model_input, + t, + text_embeddings, + ).sample + else: + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # perform clip guidance + if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(num_latent_input)[1] if do_classifier_free_guidance else text_embeddings + ) + + if self.clip_guidance_scale > 0: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + text_embeddings_clip, + self.clip_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.clip_image_guidance_scale > 0 and clip_guide_images is not None: + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + image_embeddings_clip, + self.clip_image_guidance_scale, + NUM_CUTOUTS, + USE_CUTOUTS, + ) + if self.vgg16_guidance_scale > 0 and clip_guide_images is not None: + noise_pred, latents = self.cond_fn_vgg16( + latents, t, i, text_embeddings_for_guidance, noise_pred, image_embeddings_vgg16, self.vgg16_guidance_scale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return (latents, False) + + latents = 1 / 0.18215 * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode(latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype), + ) + else: + has_nsfw_concept = None + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + # if not return_dict: + return (image, has_nsfw_concept) + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, - ) + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) - def img2img( - self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for image-to-image generation. - Args: - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. - `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The - number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added - noise will be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - init_image=init_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, + def img2img( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, - ) + ): + r""" + Function for image-to-image generation. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) - def inpaint( - self, - init_image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, - ): - r""" - Function for inpaint. - Args: - init_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more - noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. - num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - return self.__call__( - prompt=prompt, - negative_prompt=negative_prompt, - init_image=init_image, - mask_image=mask_image, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - strength=strength, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - max_embeddings_multiples=max_embeddings_multiples, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, + def inpaint( + self, + init_image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, **kwargs, - ) + ): + r""" + Function for inpaint. + Args: + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + init_image=init_image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + **kwargs, + ) - # CLIP guidance StableDiffusion - # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py + # CLIP guidance StableDiffusion + # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py - # バッチを分解して1件ずつ処理する - def cond_fn(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings_clip, clip_guidance_scale, - num_cutouts, use_cutouts=True, ): - if len(latents) == 1: - return self.cond_fn1(latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings_clip, clip_guidance_scale, - num_cutouts, use_cutouts) + # バッチを分解して1件ずつ処理する + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + if len(latents) == 1: + return self.cond_fn1( + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts, + ) - noise_pred = [] - cond_latents = [] - for i in range(len(latents)): - lat1 = latents[i].unsqueeze(0) - tem1 = text_embeddings[i].unsqueeze(0) - npo1 = noise_pred_original[i].unsqueeze(0) - gem1 = guide_embeddings_clip[i].unsqueeze(0) - npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts) - noise_pred.append(npr1) - cond_latents.append(cla1) + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings_clip[i].unsqueeze(0) + npr1, cla1 = self.cond_fn1(lat1, timestep, index, tem1, npo1, gem1, clip_guidance_scale, num_cutouts, use_cutouts) + noise_pred.append(npr1) + cond_latents.append(cla1) - noise_pred = torch.cat(noise_pred) - cond_latents = torch.cat(cond_latents) - return noise_pred, cond_latents + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents - @torch.enable_grad() - def cond_fn1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings_clip, clip_guidance_scale, - num_cutouts, use_cutouts=True, ): - latents = latents.detach().requires_grad_() + @torch.enable_grad() + def cond_fn1( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + guide_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + latents = latents.detach().requires_grad_() - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents - # predict the noise residual - noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t - # compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - fac = torch.sqrt(beta_prod_t) - sample = pred_original_sample * (fac) + latents * (1 - fac) - elif isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - sample = latents - sigma * noise_pred - else: - raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - sample = 1 / 0.18215 * sample - image = self.vae.decode(sample).sample - image = (image / 2 + 0.5).clamp(0, 1) + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) - if use_cutouts: - image = self.make_cutouts(image, num_cutouts) - else: - image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image) - image = self.normalize(image).to(latents.dtype) + if use_cutouts: + image = self.make_cutouts(image, num_cutouts) + else: + image = transforms.Resize(FEATURE_EXTRACTOR_SIZE)(image) + image = self.normalize(image).to(latents.dtype) - image_embeddings_clip = self.clip_model.get_image_features(image) - image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + image_embeddings_clip = self.clip_model.get_image_features(image) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) - if use_cutouts: - dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip) - dists = dists.view([num_cutouts, sample.shape[0], -1]) - loss = dists.sum(2).mean(0).sum() * clip_guidance_scale - else: - # バッチサイズが複数だと正しく動くかわからない - loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale + if use_cutouts: + dists = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip) + dists = dists.view([num_cutouts, sample.shape[0], -1]) + loss = dists.sum(2).mean(0).sum() * clip_guidance_scale + else: + # バッチサイズが複数だと正しく動くかわからない + loss = spherical_dist_loss(image_embeddings_clip, guide_embeddings_clip).mean() * clip_guidance_scale - grads = -torch.autograd.grad(loss, latents)[0] + grads = -torch.autograd.grad(loss, latents)[0] - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents.detach() + grads * (sigma**2) - noise_pred = noise_pred_original - else: - noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads - return noise_pred, latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents - # バッチを分解して一件ずつ処理する - def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): - if len(latents) == 1: - return self.cond_fn_vgg16_b1(latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale) + # バッチを分解して一件ずつ処理する + def cond_fn_vgg16(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + if len(latents) == 1: + return self.cond_fn_vgg16_b1( + latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale + ) - noise_pred = [] - cond_latents = [] - for i in range(len(latents)): - lat1 = latents[i].unsqueeze(0) - tem1 = text_embeddings[i].unsqueeze(0) - npo1 = noise_pred_original[i].unsqueeze(0) - gem1 = guide_embeddings[i].unsqueeze(0) - npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale) - noise_pred.append(npr1) - cond_latents.append(cla1) + noise_pred = [] + cond_latents = [] + for i in range(len(latents)): + lat1 = latents[i].unsqueeze(0) + tem1 = text_embeddings[i].unsqueeze(0) + npo1 = noise_pred_original[i].unsqueeze(0) + gem1 = guide_embeddings[i].unsqueeze(0) + npr1, cla1 = self.cond_fn_vgg16_b1(lat1, timestep, index, tem1, npo1, gem1, guidance_scale) + noise_pred.append(npr1) + cond_latents.append(cla1) - noise_pred = torch.cat(noise_pred) - cond_latents = torch.cat(cond_latents) - return noise_pred, cond_latents + noise_pred = torch.cat(noise_pred) + cond_latents = torch.cat(cond_latents) + return noise_pred, cond_latents - # 1件だけ処理する - @torch.enable_grad() - def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): - latents = latents.detach().requires_grad_() + # 1件だけ処理する + @torch.enable_grad() + def cond_fn_vgg16_b1(self, latents, timestep, index, text_embeddings, noise_pred_original, guide_embeddings, guidance_scale): + latents = latents.detach().requires_grad_() - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latents / ((sigma**2 + 1) ** 0.5) - else: - latent_model_input = latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents - # predict the noise residual - noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample - if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - beta_prod_t = 1 - alpha_prod_t - # compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - fac = torch.sqrt(beta_prod_t) - sample = pred_original_sample * (fac) + latents * (1 - fac) - elif isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[index] - sample = latents - sigma * noise_pred - else: - raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - sample = 1 / 0.18215 * sample - image = self.vae.decode(sample).sample - image = (image / 2 + 0.5).clamp(0, 1) - image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image) - image = self.vgg16_normalize(image).to(latents.dtype) + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + image = transforms.Resize((image.shape[-2] // VGG16_INPUT_RESIZE_DIV, image.shape[-1] // VGG16_INPUT_RESIZE_DIV))(image) + image = self.vgg16_normalize(image).to(latents.dtype) - image_embeddings = self.vgg16_feat_model(image)['feat'] + image_embeddings = self.vgg16_feat_model(image)["feat"] - # バッチサイズが複数だと正しく動くかわからない - loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので + # バッチサイズが複数だと正しく動くかわからない + loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので - grads = -torch.autograd.grad(loss, latents)[0] - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents.detach() + grads * (sigma**2) - noise_pred = noise_pred_original - else: - noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads - return noise_pred, latents + grads = -torch.autograd.grad(loss, latents)[0] + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents class MakeCutouts(torch.nn.Module): - def __init__(self, cut_size, cut_power=1.0): - super().__init__() + def __init__(self, cut_size, cut_power=1.0): + super().__init__() - self.cut_size = cut_size - self.cut_power = cut_power + self.cut_size = cut_size + self.cut_power = cut_power - def forward(self, pixel_values, num_cutouts): - sideY, sideX = pixel_values.shape[2:4] - max_size = min(sideX, sideY) - min_size = min(sideX, sideY, self.cut_size) - cutouts = [] - for _ in range(num_cutouts): - size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) - offsetx = torch.randint(0, sideX - size + 1, ()) - offsety = torch.randint(0, sideY - size + 1, ()) - cutout = pixel_values[:, :, offsety: offsety + size, offsetx: offsetx + size] - cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size)) - return torch.cat(cutouts) + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(torch.nn.functional.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) def spherical_dist_loss(x, y): - x = torch.nn.functional.normalize(x, dim=-1) - y = torch.nn.functional.normalize(y, dim=-1) - return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + x = torch.nn.functional.normalize(x, dim=-1) + y = torch.nn.functional.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) re_attention = re.compile( @@ -1477,151 +1590,151 @@ re_attention = re.compile( def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ - res = [] - round_brackets = [] - square_brackets = [] + res = [] + round_brackets = [] + square_brackets = [] - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) - if len(res) == 0: - res = [["", 1.0]] + if len(res) == 0: + res = [["", 1.0]] - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 - return res + return res def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word).input_ids[1:-1] + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] - token = pipe.replace_token(token) + token = pipe.replace_token(token) - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2): min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] - return tokens, weights + return tokens, weights def get_unweighted_text_embeddings( @@ -1633,56 +1746,56 @@ def get_unweighted_text_embeddings( pad: int, no_boseos_middle: Optional[bool] = True, ): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2): (i + 1) * (chunk_length - 2) + 2].clone() + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos - if clip_skip is None or clip_skip == 1: - text_embedding = pipe.text_encoder(text_input_chunk)[0] - else: - enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out['hidden_states'][-clip_skip] - text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) + if clip_skip is None or clip_skip == 1: + text_embedding = pipe.text_encoder(text_input_chunk)[0] + else: + enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - if clip_skip is None or clip_skip == 1: - text_embeddings = pipe.text_encoder(text_input)[0] + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) else: - enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out['hidden_states'][-clip_skip] - text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) - return text_embeddings + if clip_skip is None or clip_skip == 1: + text_embeddings = pipe.text_encoder(text_input)[0] + else: + enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) + return text_embeddings def get_weighted_text_embeddings( @@ -1696,84 +1809,69 @@ def get_weighted_text_embeddings( clip_skip=None, **kwargs, ): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - Args: - pipe (`DiffusionPipeline`): - Pipe to provide access to the tokenizer and the text encoder. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - uncond_prompt (`str` or `List[str]`): - The unconditional prompt or prompts for guide the image generation. If unconditional prompt - is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `1`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. - skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. - """ - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `1`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) - else: - prompt_tokens = [ - token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids - ] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] - for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + max_length = max(max_length, max([len(token) for token in uncond_tokens])) - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id - eos = pipe.tokenizer.eos_token_id - pad = pipe.tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, max_length, bos, eos, @@ -1781,86 +1879,100 @@ def get_weighted_text_embeddings( no_boseos_middle=no_boseos_middle, chunk_length=pipe.tokenizer.model_max_length, ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.tokenizer.model_max_length, - clip_skip, - eos, pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( pipe, - uncond_tokens, + prompt_tokens, pipe.tokenizer.model_max_length, clip_skip, - eos, pad, + eos, + pad, no_boseos_middle=no_boseos_middle, ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - # →全体でいいんじゃないかな - if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - if uncond_prompt is not None: - return text_embeddings, uncond_embeddings, prompt_tokens - return text_embeddings, None, prompt_tokens + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings, prompt_tokens + return text_embeddings, None, prompt_tokens def preprocess_guide_image(image): - image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # cond_fnと合わせる - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) # nchw - image = torch.from_numpy(image) - return image # 0 to 1 + image = image.resize(FEATURE_EXTRACTOR_SIZE, resample=Image.NEAREST) # cond_fnと合わせる + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 # VGG16の入力は任意サイズでよいので入力画像を適宜リサイズする def preprocess_vgg16_guide_image(image, size): - image = image.resize(size, resample=Image.NEAREST) # cond_fnと合わせる - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) # nchw - image = torch.from_numpy(image) - return image # 0 to 1 + image = image.resize(size, resample=Image.NEAREST) # cond_fnと合わせる + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) # nchw + image = torch.from_numpy(image) + return image # 0 to 1 def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask # endregion @@ -1873,924 +1985,1086 @@ def preprocess_mask(mask): class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] + # バッチ分割が必要なデータ + width: int + height: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt + return_latents: bool + base: BatchDataBase + ext: BatchDataExt def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - - # モデルを読み込む - if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - - use_stable_diffusion_format = os.path.isfile(args.ckpt) - if use_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) - else: - print("load Diffusers pretrained models") - loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) - text_encoder = loading_pipe.text_encoder - vae = loading_pipe.vae - unet = loading_pipe.unet - tokenizer = loading_pipe.tokenizer - del loading_pipe - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") - - # # 置換するCLIPを読み込む - # if args.replace_clip_l14_336: - # text_encoder = load_clip_l14_336(dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") - - if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: - print("prepare clip model") - clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) - else: - clip_model = None - - if args.vgg16_guidance_scale > 0.0: - print("prepare resnet model") - vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) - else: - vgg16_model = None - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - replace_unet_modules(unet, not args.xformers, args.xformers) - - # tokenizerを読み込む - print("loading tokenizer") - if use_stable_diffusion_format: - tokenizer = train_util.load_tokenizer(args) - - # schedulerを用意する - sched_init_args = {} - scheduler_num_noises_per_step = 1 - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - elif args.sampler == 'lms' or args.sampler == 'k_lms': - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - elif args.sampler == 'euler' or args.sampler == 'k_euler': - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - elif args.sampler == 'euler_a' or args.sampler == 'k_euler_a': - scheduler_cls = EulerAncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args['algorithm_type'] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - elif args.sampler == 'dpm_2' or args.sampler == 'k_dpm_2': - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - elif args.sampler == 'dpm_2_a' or args.sampler == 'k_dpm_2_a': - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - - if args.v_parameterization: - sched_init_args['prediction_type'] = 'v_prediction' - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index = 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == 'randn': - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls(num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, **sched_init_args) - - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - print("set clip_sample to True") - scheduler.config.clip_sample = True - - # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない - - # custom pipelineをコピったやつを生成する - vae.to(dtype).to(device) - text_encoder.to(dtype).to(device) - unet.to(dtype).to(device) - if clip_model is not None: - clip_model.to(dtype).to(device) - if vgg16_model is not None: - vgg16_model.to(dtype).to(device) - - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") - - network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs) - else: - raise ValueError("No weight. Weight is required.") - if network is None: - return - - network.apply_to(text_encoder, unet) - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - networks.append(network) - else: - networks = [] - - # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] - if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - - ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) - - if args.opt_channels_last: - print(f"set optimizing: channels last") - text_encoder.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if clip_model is not None: - clip_model.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - if vgg16_model is not None: - vgg16_model.to(memory_format=torch.channels_last) - - for cn in control_nets: - cn.unet.to(memory_format=torch.channels_last) - cn.net.to(memory_format=torch.channels_last) - - pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip, - clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale, - vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer) - pipe.set_control_nets(control_nets) - print("pipeline is ready.") - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Textual Inversionを処理する - if args.textual_inversion_embeddings: - token_ids_embeds = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - embeds = next(iter(data.values())) - if type(embeds) != torch.Tensor: - raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") - - num_vectors_per_token = embeds.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens = tokenizer.add_tokens(token_strings) - assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" - - token_ids = tokenizer.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") - assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" - assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(token_ids[0], token_ids) - - token_ids_embeds.append((token_ids, embeds)) - - text_encoder.resize_token_embeddings(len(tokenizer)) - token_embeds = text_encoder.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds: - for token_id, embed in zip(token_ids, embeds): - token_embeds[token_id] = embed - - # promptを取得する - if args.from_file is not None: - print(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] - elif args.prompt is not None: - prompt_list = [args.prompt] - else: - prompt_list = [] - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 else: - paths = glob.glob(os.path.join(path, "*.png")) + glob.glob(os.path.join(path, "*.jpg")) + \ - glob.glob(os.path.join(path, "*.jpeg")) + glob.glob(os.path.join(path, "*.webp")) - paths.sort() + dtype = torch.float32 - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) + highres_fix = args.highres_fix_scale is not None + assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - return images + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, 'filename'): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized + # モデルを読み込む + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] - if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") - else: - init_images = None + use_stable_diffusion_format = os.path.isfile(args.ckpt) + if use_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) + else: + print("load Diffusers pretrained models") + loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) + text_encoder = loading_pipe.text_encoder + vae = loading_pipe.vae + unet = loading_pipe.unet + tokenizer = loading_pipe.tokenizer + del loading_pipe - if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, dtype) + print("additional VAE loaded") - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") - for img in init_images: - if 'prompt' in img.text: - prompt = img.text['prompt'] - if 'negative-prompt' in img.text: - prompt += " --n " + img.text['negative-prompt'] - prompt_list.append(prompt) + # # 置換するCLIPを読み込む + # if args.replace_clip_l14_336: + # text_encoder = load_clip_l14_336(dtype) + # print(f"large clip {CLIP_ID_L14_336} is loaded") - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l + if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale: + print("prepare clip model") + clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype) + else: + clip_model = None - if mask_images is not None: - l = [] - for im in mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l + if args.vgg16_guidance_scale > 0.0: + print("prepare resnet model") + vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1) + else: + vgg16_model = None - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not None and args.H is not None: - if init_images is not None: - print(f"resize img2img source images to {args.W}*{args.H}") - init_images = resize_images(init_images, (args.W, args.H)) - if mask_images is not None: - print(f"resize img2img mask images to {args.W}*{args.H}") - mask_images = resize_images(mask_images, (args.W, args.H)) + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + replace_unet_modules(unet, not args.xformers, args.xformers) - if networks and mask_images: - # mask を領域情報として流用する、現在は1枚だけ対応 - # TODO 複数のnetwork classの混在時の考慮 - print("use mask as region") - # import cv2 - # for i in range(3): - # cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) - # cv2.waitKey() - # cv2.destroyAllWindows() - networks[0].__class__.set_regions(networks, np.array(mask_images[0])) - mask_images = None + # tokenizerを読み込む + print("loading tokenizer") + if use_stable_diffusion_format: + tokenizer = train_util.load_tokenizer(args) - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) + # schedulerを用意する + sched_init_args = {} + scheduler_num_noises_per_step = 1 + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 - print(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") - guide_images = None - else: - guide_images = None + if args.v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" - # seed指定時はseedを決めておく - if args.seed is not None: - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7fffffff) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed - else: - predefined_seeds = None + # samplerの乱数をあらかじめ指定するための処理 - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 512 - if args.H is None: - args.H = 512 + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises - for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7fffffff) + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - print("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - width_1st = int(ext.width * args.highres_fix_scale + .5) - height_1st = int(ext.height * args.highres_fix_scale + .5) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 + self.sampler_noise_index += 1 + return noise - ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale, - ext.negative_scale, ext.strength, ext.network_muls) - batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st)) - images_1st = process_batch(batch_1st, True, True) + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager - # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") - if args.highres_fix_latents_upscaling: - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode='bilinear') # , antialias=True) - images_1st = images_1st.to(org_dtype) + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - if not args.highres_fix_latents_upscaling: - image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) - # このバッチの情報を取り出す - return_latents, (step_first, _, _, _, init_image, mask_image, _, guide_image), \ - (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) - prompts = [] - negative_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step)] - seeds = [] - clip_prompts = [] + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + print("set clip_sample to True") + scheduler.config.clip_sample = True - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない - if mask_image is not None: - mask_images = [] + # custom pipelineをコピったやつを生成する + vae.to(dtype).to(device) + text_encoder.to(dtype).to(device) + unet.to(dtype).to(device) + if clip_model is not None: + clip_model.to(dtype).to(device) + if vgg16_model is not None: + vgg16_model.to(dtype).to(device) + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_default_muls.append(network_mul) + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoder, unet, **net_kwargs + ) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return + + network.apply_to(text_encoder, unet) + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) + else: + networks = [] + + # ControlNetの処理 + control_nets: List[ControlNetInfo] = [] + if args.control_net_models: + for i, model in enumerate(args.control_net_models): + prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + + ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model) + prep = original_control_net.load_preprocess(prep_type) + control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + + if args.opt_channels_last: + print(f"set optimizing: channels last") + text_encoder.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if clip_model is not None: + clip_model.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + if vgg16_model is not None: + vgg16_model.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.unet.to(memory_format=torch.channels_last) + cn.net.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + args.clip_skip, + clip_model, + args.clip_guidance_scale, + args.clip_image_guidance_scale, + vgg16_model, + args.vgg16_guidance_scale, + args.vgg16_guidance_layer, + ) + pipe.set_control_nets(control_nets) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + embeds = next(iter(data.values())) + + if type(embeds) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}") + + num_vectors_per_token = embeds.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}") + assert ( + min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1 + ), f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(token_ids[0], token_ids) + + token_ids_embeds.append((token_ids, embeds)) + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds: + for token_id, embed in zip(token_ids, embeds): + token_embeds[token_id] = embed + + # promptを取得する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() - if guide_image is not None: # CLIP image guided? - guide_images = [] - else: - guide_images = None + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - - images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code, - output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, return_latents=return_latents, - clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] - if highres_1st and not args.highres_fix_save_1st: # return images or latents return images - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate(zip(images, prompts, negative_prompts, seeds, clip_prompts)): - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + else: + init_images = None - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' meta data") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: - # interactive - valid = False - while not valid: - print("\nType prompt:") - try: - prompt = input() - except EOFError: - break - - valid = len(prompt.strip().split(' --')[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - prompt = prompt_list[prompt_index] - - # parse prompt - width = args.W - height = args.H - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - prompt_args = prompt.strip().split(' --') - prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - try: - m = re.match(r'w (\d+)', parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue - - m = re.match(r'h (\d+)', parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue - - m = re.match(r's (\d+)', parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue - - m = re.match(r'd ([\d,]+)', parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(',')] - print(f"seeds: {seeds}") - continue - - m = re.match(r'l ([\d\.]+)', parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue - - m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == 'none': - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue - - m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - print(f"strength: {strength}") - continue - - m = re.match(r'n (.+)', parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r'c (.+)', parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - if seeds is not None: - # 数が足りないなら繰り返す - if len(seeds) < args.images_per_prompt: - seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) - seeds = seeds[:args.images_per_prompt] - else: - if predefined_seeds is not None: - seeds = predefined_seeds[-args.images_per_prompt:] - predefined_seeds = predefined_seeds[:-args.images_per_prompt] - elif args.iter_same_seed: - seeds = [iter_seed] * args.images_per_prompt - else: - seeds = [random.randint(0, 0x7fffffff) for _ in range(args.images_per_prompt)] - if args.interactive: - print(f"seed: {seeds}") - - init_image = mask_image = guide_image = None - for seed in seeds: # images_per_promptの数だけ - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # 32単位に丸めたやつにresizeされるので踏襲する - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - print(f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます") + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c:p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: - if prev_image is None: - print("Generate 1st image without guide image.") - else: - print("Use previous image as guide image.") - guide_image = prev_image + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + if init_images is not None: + print(f"resize img2img source images to {args.W}*{args.H}") + init_images = resize_images(init_images, (args.W, args.H)) + if mask_images is not None: + print(f"resize img2img mask images to {args.W}*{args.H}") + mask_images = resize_images(mask_images, (args.W, args.H)) - b1 = BatchData(False, BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), - BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None)) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? - process_batch(batch_data, highres_fix) - batch_data.clear() + if networks and mask_images: + # mask を領域情報として流用する、現在は1枚だけ対応 + # TODO 複数のnetwork classの混在時の考慮 + print("use mask as region") + # import cv2 + # for i in range(3): + # cv2.imshow("msk", np.array(mask_images[0])[:,:,i]) + # cv2.waitKey() + # cv2.destroyAllWindows() + networks[0].__class__.set_regions(networks, np.array(mask_images[0])) + mask_images = None - batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] - batch_data.clear() + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) - global_step += 1 + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + guide_images = None + else: + guide_images = None - prompt_index += 1 + # seed指定時はseedを決めておく + if args.seed is not None: + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None - if len(batch_data) > 0: - process_batch(batch_data, highres_fix) - batch_data.clear() + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 512 + if args.H is None: + args.H = 512 - print("done!") + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + width_1st = int(ext.width * args.highres_fix_scale + 0.5) + height_1st = int(ext.height * args.highres_fix_scale + 0.5) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + ext_1st = BatchDataExt( + width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls + ) + batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st)) + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + print("process 2nd stage") + if args.highres_fix_latents_upscaling: + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + if not args.highres_fix_latents_upscaling: + image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定 + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image), + (width, height, steps, scale, negative_scale, strength, network_muls), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + )[0] + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts) + ): + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + prompt = input() + except EOFError: + break + + valid = len(prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + prompt = prompt_list[prompt_index] + + # parse prompt + width = args.W + height = args.H + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + prompt_args = prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + if seeds is not None: + # 数が足りないなら繰り返す + if len(seeds) < args.images_per_prompt: + seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) + seeds = seeds[: args.images_per_prompt] + else: + if predefined_seeds is not None: + seeds = predefined_seeds[-args.images_per_prompt :] + predefined_seeds = predefined_seeds[: -args.images_per_prompt] + elif args.iter_same_seed: + seeds = [iter_seed] * args.images_per_prompt + else: + seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] + if args.interactive: + print(f"seed: {seeds}") + + init_image = mask_image = guide_image = None + for seed in seeds: # images_per_promptの数だけ + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # 32単位に丸めたやつにresizeされるので踏襲する + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: + if prev_image is None: + print("Generate 1st image without guide image.") + else: + print("Use previous image as guide image.") + guide_image = prev_image + + b1 = BatchData( + False, + BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataExt( + width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument("--from_file", type=str, default=None, - help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む") - parser.add_argument("--interactive", action='store_true', help='interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)') - parser.add_argument("--no_preview", action='store_true', help='do not show generated image in interactive mode / 対話モードで画像を表示しない') - parser.add_argument("--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像") - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action='store_true', help="sequential output file name / 生成画像のファイル名を連番にする") - parser.add_argument("--use_original_file_name", action='store_true', - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける") - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument("--vae_batch_size", type=float, default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率") - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument('--sampler', type=str, default='ddim', - choices=['ddim', 'pndm', 'lms', 'euler', 'euler_a', 'heun', 'dpm_2', 'dpm_2_a', 'dpmsolver', - 'dpmsolver++', 'dpmsingle', - 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'], - help=f'sampler (scheduler) type / サンプラー(スケジューラ)の種類') - parser.add_argument("--scale", type=float, default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale") - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--tokenizer_cache_dir", type=str, default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)") - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument("--seed", type=int, default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed") - parser.add_argument("--iter_same_seed", action='store_true', - help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)') - parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する') - parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') - parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') - parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') - parser.add_argument("--opt_channels_last", action='store_true', - help='set channels last option to model / モデルにchannels lastを指定し最適化する') - parser.add_argument("--network_module", type=str, default=None, nargs='*', - help='additional network module to use / 追加ネットワークを使う時そのモジュール名') - parser.add_argument("--network_weights", type=str, default=None, nargs='*', - help='additional network weights to load / 追加ネットワークの重み') - parser.add_argument("--network_mul", type=float, default=None, nargs='*', - help='additional network multiplier / 追加ネットワークの効果の倍率') - parser.add_argument("--network_args", type=str, default=None, nargs='*', - help='additional argmuments for network (key=value) / ネットワークへの追加の引数') - parser.add_argument("--network_show_meta", action='store_true', - help='show metadata of network model / ネットワークモデルのメタデータを表示する') - parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*', - help='Embeddings files of Textual Inversion / Textual Inversionのembeddings') - parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') - parser.add_argument("--max_embeddings_multiples", type=int, default=None, - help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') - parser.add_argument("--clip_guidance_scale", type=float, default=0.0, - help='enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only) / CLIP guided SDを有効にしてこのscaleを適用する(サンプラーはDDIM、PNDM、LMSのみ)') - parser.add_argument("--clip_image_guidance_scale", type=float, default=0.0, - help='enable CLIP guided SD by image, scale for guidance / 画像によるCLIP guided SDを有効にしてこのscaleを適用する') - parser.add_argument("--vgg16_guidance_scale", type=float, default=0.0, - help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する') - parser.add_argument("--vgg16_guidance_layer", type=int, default=20, - help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)') - parser.add_argument("--guide_image_path", type=str, default=None, nargs="*", - help="image to CLIP guidance / CLIP guided SDでガイドに使う画像") - parser.add_argument("--highres_fix_scale", type=float, default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする") - parser.add_argument("--highres_fix_steps", type=int, default=28, - help="1st stage steps for highres fix / highres fixの最初のステージのステップ数") - parser.add_argument("--highres_fix_save_1st", action='store_true', - help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") - parser.add_argument("--highres_fix_latents_upscaling", action='store_true', - help="use latents upscaling for highres fix / highres fixでlatentで拡大する") - parser.add_argument("--negative_scale", type=float, default=None, - help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する") + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + ) + parser.add_argument( + "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + ) + parser.add_argument( + "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--clip_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD, scale for guidance (DDIM, PNDM, LMS samplers only) / CLIP guided SDを有効にしてこのscaleを適用する(サンプラーはDDIM、PNDM、LMSのみ)", + ) + parser.add_argument( + "--clip_image_guidance_scale", + type=float, + default=0.0, + help="enable CLIP guided SD by image, scale for guidance / 画像によるCLIP guided SDを有効にしてこのscaleを適用する", + ) + parser.add_argument( + "--vgg16_guidance_scale", + type=float, + default=0.0, + help="enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する", + ) + parser.add_argument( + "--vgg16_guidance_layer", + type=int, + default=20, + help="layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + ) + parser.add_argument( + "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + ) - parser.add_argument("--control_net_models", type=str, default=None, nargs='*', - help='ControlNet models to use / 使用するControlNetのモデル名') - parser.add_argument("--control_net_preps", type=str, default=None, nargs='*', - help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名') - parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み') - parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*', - help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率') + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + ) + parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() - main(args) + args = parser.parse_args() + main(args) diff --git a/library/common_gui.py b/library/common_gui.py index 25e7379..addd360 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -840,6 +840,7 @@ def gradio_advanced_training(): xformers = gr.Checkbox(label='Use xformers', value=True) color_aug = gr.Checkbox(label='Color augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False) + min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1) with gr.Row(): bucket_no_upscale = gr.Checkbox( label="Don't upscale bucket resolution", value=True @@ -914,6 +915,7 @@ def gradio_advanced_training(): noise_offset, additional_parameters, vae_batch_size, + min_snr_gamma, ) @@ -949,13 +951,15 @@ def run_cmd_advanced_training(**kwargs): f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' if int(kwargs.get('bucket_reso_steps', 64)) >= 1 else '', + f' --min_snr_gamma={int(kwargs.get("min_snr_gamma", 0))}' + if int(kwargs.get('min_snr_gamma', 0)) >= 1 + else '', ' --save_state' if kwargs.get('save_state') else '', ' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '', ' --color_aug' if kwargs.get('color_aug') else '', ' --flip_aug' if kwargs.get('flip_aug') else '', ' --shuffle_caption' if kwargs.get('shuffle_caption') else '', - ' --gradient_checkpointing' - if kwargs.get('gradient_checkpointing') + ' --gradient_checkpointing' if kwargs.get('gradient_checkpointing') else '', ' --full_fp16' if kwargs.get('full_fp16') else '', ' --xformers' if kwargs.get('xformers') else '', diff --git a/library/config_util.py b/library/config_util.py index e62bfb8..97bbb4a 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -4,6 +4,7 @@ from dataclasses import ( dataclass, ) import functools +import random from textwrap import dedent, indent import json from pathlib import Path @@ -56,6 +57,8 @@ class BaseSubsetParams: caption_dropout_rate: float = 0.0 caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 @dataclass class DreamBoothSubsetParams(BaseSubsetParams): @@ -137,6 +140,8 @@ class ConfigSanitizer: "random_crop": bool, "shuffle_caption": bool, "keep_tokens": int, + "token_warmup_min": int, + "token_warmup_step": Any(float,int), } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -406,6 +411,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, """), " ") if is_dreambooth: @@ -422,9 +429,12 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu print(info) # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no for i, dataset in enumerate(datasets): print(f"[Dataset {i}]") dataset.make_buckets() + dataset.set_seed(seed) return DatasetGroup(datasets) @@ -491,7 +501,6 @@ def load_user_config(file: str) -> dict: return config - # for config test if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 7d42c5d..4d844be 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,14 +1,18 @@ import torch +import argparse -def apply_snr_weight(loss, noisy_latents, latents, gamma): - gamma = gamma - if gamma: - sigma = torch.sub(noisy_latents, latents) - zeros = torch.zeros_like(sigma) - alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) - sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) - snr = torch.div(alpha_mean_sq, sigma_mean_sq) - gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() - loss = loss * snr_weight - return loss +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + snr = torch.stack([all_snr[t] for t in timesteps]) + gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + loss = loss * snr_weight + return loss + +def add_custom_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨") diff --git a/library/model_util.py b/library/model_util.py index d1020c0..3d8e753 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -1046,10 +1046,14 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p key_count = len(state_dict.keys()) new_ckpt = {'state_dict': state_dict} - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] + # epoch and global_step are sometimes not int + try: + if 'epoch' in checkpoint: + epochs += checkpoint['epoch'] + if 'global_step' in checkpoint: + steps += checkpoint['global_step'] + except: + pass new_ckpt['epoch'] = epochs new_ckpt['global_step'] = steps diff --git a/library/train_util.py b/library/train_util.py index b42f894..e1a8e92 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -276,6 +276,8 @@ class BaseSubset: caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, + token_warmup_min: int, + token_warmup_step: Union[float, int], ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -289,6 +291,9 @@ class BaseSubset: self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_tag_dropout_rate = caption_tag_dropout_rate + self.token_warmup_min = token_warmup_min # step=0におけるタグの数 + self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.img_count = 0 @@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) self.is_reg = is_reg @@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) self.metadata_file = metadata_file @@ -404,6 +417,10 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + self.current_step: int = 0 + self.max_train_steps: int = 0 + self.seed: int = 0 + # augmentation self.aug_helper = AugHelper() @@ -419,9 +436,19 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} + def set_seed(self, seed): + self.seed = seed + def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする + self.shuffle_buckets() self.current_epoch = epoch - self.shuffle_buckets() + + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) @@ -452,7 +479,16 @@ class BaseDataset(torch.utils.data.Dataset): if is_drop_out: caption = "" else: - if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: + tokens = [t.strip() for t in caption.strip().split(",")] + if subset.token_warmup_step < 1: # 初回に上書きする + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = ( + math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step))) + + subset.token_warmup_min + ) + tokens = tokens[:tokens_len] def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -464,10 +500,10 @@ class BaseDataset(torch.utils.data.Dataset): return l fixed_tokens = [] - flex_tokens = [t.strip() for t in caption.strip().split(",")] + flex_tokens = tokens[:] if subset.keep_tokens > 0: fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = flex_tokens[subset.keep_tokens :] + flex_tokens = tokens[subset.keep_tokens :] if subset.shuffle_caption: random.shuffle(flex_tokens) @@ -637,6 +673,9 @@ class BaseDataset(torch.utils.data.Dataset): self._length = len(self.buckets_indices) def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + random.shuffle(self.buckets_indices) self.bucket_manager.shuffle() @@ -1043,7 +1082,7 @@ class DreamBoothDataset(BaseDataset): self.register_image(info, subset) n += info.num_repeats else: - info.num_repeats += 1 + info.num_repeats += 1 # rewrite registered info n += 1 if n >= num_train_images: break @@ -1104,6 +1143,8 @@ class FineTuningDataset(BaseDataset): # path情報を作る if os.path.exists(image_key): abs_path = image_key + elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" else: npz_path = os.path.join(subset.image_dir, image_key + ".npz") if os.path.exists(npz_path): @@ -1285,6 +1326,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.set_current_epoch(epoch) + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) + def disable_token_padding(self): for dataset in self.datasets: dataset.disable_token_padding() @@ -1292,37 +1341,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") + print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") - train_dataset.set_current_epoch(1) - k = 0 - indices = list(range(len(train_dataset))) - random.shuffle(indices) - for i, idx in enumerate(indices): - example = train_dataset[idx] - if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid) in enumerate( - zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) - ): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') - if show_input_ids: - print(f"input ids: {iid}") - if example["images"] is not None: - im = example["images"][j] - print(f"image size: {im.size()}") - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - if os.name == "nt": # only windows - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27 or (example["images"] is None and i >= 8): + epoch = 1 + while True: + print(f"epoch: {epoch}") + + steps = (epoch - 1) * len(train_dataset) + 1 + indices = list(range(len(train_dataset))) + random.shuffle(indices) + + k = 0 + for i, idx in enumerate(indices): + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) + print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + + example = train_dataset[idx] + if example["latents"] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate( + zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + ): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example["images"] is not None: + im = example["images"][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27 or k == ord("s") or k == ord("e"): + break + steps += 1 + + if k == ord("e"): + break + if k == 27 or (example["images"] is None and i >= 8): + k = 27 + break + if k == 27: break + epoch += 1 + def glob_images(directory, base="*"): img_paths = [] @@ -1331,8 +1398,8 @@ def glob_images(directory, base="*"): img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) else: img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - # img_paths = list(set(img_paths)) # 重複を排除 - # img_paths.sort() + img_paths = list(set(img_paths)) # 重複を排除 + img_paths.sort() return img_paths @@ -1344,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive): else: for ext in IMAGE_EXTENSIONS: image_paths += list(dir_path.glob("*" + ext)) - # image_paths = list(set(image_paths)) # 重複を排除 - # image_paths.sort() + image_paths = list(set(image_paths)) # 重複を排除 + image_paths.sort() return image_paths @@ -1963,9 +2030,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - - parser.add_argument("--min_snr_gamma", type=float, default=5, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") - def verify_training_args(args: argparse.Namespace): @@ -2041,6 +2105,20 @@ def add_dataset_arguments( "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" ) + parser.add_argument( + "--token_warmup_min", + type=int, + default=1, + help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", + ) + + parser.add_argument( + "--token_warmup_step", + type=float, + default=0, + help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", + ) + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに @@ -2975,3 +3053,24 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # endregion + + +# collate_fn用 epoch,stepはmultiprocessing.Value +class collater_class: + def __init__(self, epoch, step, dataset): + self.current_epoch = epoch + self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing + + def __call__(self, examples): + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] diff --git a/lora_gui.py b/lora_gui.py index 07ac1f5..af3b0ae 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -123,7 +123,9 @@ def save_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -240,7 +242,9 @@ def open_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -257,7 +261,7 @@ def open_configuration( with open(file_path, 'r') as f: my_data = json.load(f) print('Loading config...') - + # Update values to fix deprecated use_8bit_adam checkbox, set appropriate optimizer if it is set to True, etc. my_data = update_my_data(my_data) else: @@ -348,7 +352,9 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): print_only_bool = True if print_only.get('label') == 'True' else False @@ -420,13 +426,15 @@ def train_model( [ f for f, lower_f in ( - (file, file.lower()) for file in os.listdir(os.path.join(train_data_dir, folder)) + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) ) if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) ] ) - print(f'Folder {folder}: {num_images} images found') # Calculate the total number of steps for this folder @@ -591,6 +599,7 @@ def train_model( noise_offset=noise_offset, additional_parameters=additional_parameters, vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, ) run_cmd += run_cmd_sample( @@ -649,10 +658,12 @@ def lora_tab( v_parameterization, save_model_as, model_list, - ) = gradio_source_model(save_model_as_choices = [ - 'ckpt', - 'safetensors', - ]) + ) = gradio_source_model( + save_model_as_choices=[ + 'ckpt', + 'safetensors', + ] + ) with gr.Tab('Folders'): with gr.Row(): @@ -897,6 +908,7 @@ def lora_tab( noise_offset, additional_parameters, vae_batch_size, + min_snr_gamma, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -1015,6 +1027,7 @@ def lora_tab( sample_prompts, additional_parameters, vae_batch_size, + min_snr_gamma, ] button_open_config.click( @@ -1104,7 +1117,7 @@ def UI(**kwargs): if kwargs.get('inbrowser', False): launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False) if kwargs.get('listen', True): - launch_kwargs['server_name'] = "0.0.0.0" + launch_kwargs['server_name'] = '0.0.0.0' print(launch_kwargs) interface.launch(**launch_kwargs) @@ -1128,7 +1141,9 @@ if __name__ == '__main__': '--inbrowser', action='store_true', help='Open in browser' ) parser.add_argument( - '--listen', action='store_true', help='Launch gradio with server name 0.0.0.0, allowing LAN access' + '--listen', + action='store_true', + help='Launch gradio with server name 0.0.0.0, allowing LAN access', ) args = parser.parse_args() diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 2bd8659..7b74063 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -11,6 +11,8 @@ import numpy as np MIN_SV = 1e-6 +# Model save and load functions + def load_state_dict(file_name, dtype): if model_util.is_safetensors(file_name): sd = load_file(file_name) @@ -39,12 +41,13 @@ def save_to_file(file_name, model, state_dict, dtype, metadata): torch.save(model, file_name) +# Indexing functions + def index_sv_cumulative(S, target): original_sum = float(torch.sum(S)) cumulative_sums = torch.cumsum(S, dim=0)/original_sum index = int(torch.searchsorted(cumulative_sums, target)) + 1 - if index >= len(S): - index = len(S) - 1 + index = max(1, min(index, len(S)-1)) return index @@ -54,8 +57,16 @@ def index_sv_fro(S, target): s_fro_sq = float(torch.sum(S_squared)) sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 - if index >= len(S): - index = len(S) - 1 + index = max(1, min(index, len(S)-1)) + + return index + + +def index_sv_ratio(S, target): + max_sv = S[0] + min_sv = max_sv/target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S)-1)) return index @@ -125,26 +136,24 @@ def merge_linear(lora_down, lora_up, device): return weight +# Calculate new rank + def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict = {} if dynamic_method=="sv_ratio": # Calculate new dim and alpha based off ratio - max_sv = S[0] - min_sv = max_sv/dynamic_param - new_rank = max(torch.sum(S > min_sv).item(),1) + new_rank = index_sv_ratio(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) elif dynamic_method=="sv_cumulative": # Calculate new dim and alpha based off cumulative sum - new_rank = index_sv_cumulative(S, dynamic_param) - new_rank = max(new_rank, 1) + new_rank = index_sv_cumulative(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) elif dynamic_method=="sv_fro": # Calculate new dim and alpha based off sqrt sum of squares - new_rank = index_sv_fro(S, dynamic_param) - new_rank = min(max(new_rank, 1), len(S)-1) + new_rank = index_sv_fro(S, dynamic_param) + 1 new_alpha = float(scale*new_rank) else: new_rank = rank @@ -172,7 +181,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): param_dict["new_alpha"] = new_alpha param_dict["sum_retained"] = (s_rank)/s_sum param_dict["fro_retained"] = fro_percent - param_dict["max_ratio"] = S[0]/S[new_rank] + param_dict["max_ratio"] = S[0]/S[new_rank - 1] return param_dict diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index d7ed087..da5467d 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -112,7 +112,9 @@ def save_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -225,7 +227,9 @@ def open_configuration( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -320,7 +324,9 @@ def train_model( sample_every_n_epochs, sample_sampler, sample_prompts, - additional_parameters,vae_batch_size, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -376,7 +382,10 @@ def train_model( [ f for f, lower_f in ( - (file, file.lower()) for file in os.listdir(os.path.join(train_data_dir, folder)) + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) ) if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) ] @@ -511,6 +520,7 @@ def train_model( noise_offset=noise_offset, additional_parameters=additional_parameters, vae_batch_size=vae_batch_size, + min_snr_gamma=min_snr_gamma, ) run_cmd += f' --token_string="{token_string}"' run_cmd += f' --init_word="{init_word}"' @@ -569,10 +579,12 @@ def ti_tab( v_parameterization, save_model_as, model_list, - ) = gradio_source_model(save_model_as_choices = [ - 'ckpt', - 'safetensors', - ]) + ) = gradio_source_model( + save_model_as_choices=[ + 'ckpt', + 'safetensors', + ] + ) with gr.Tab('Folders'): with gr.Row(): @@ -774,6 +786,7 @@ def ti_tab( noise_offset, additional_parameters, vae_batch_size, + min_snr_gamma, ) = gradio_advanced_training() color_aug.change( color_aug_changed, @@ -881,6 +894,7 @@ def ti_tab( sample_prompts, additional_parameters, vae_batch_size, + min_snr_gamma, ] button_open_config.click( diff --git a/train_db - Copy.py b/train_db - Copy.py new file mode 100644 index 0000000..f441d5d --- /dev/null +++ b/train_db - Copy.py @@ -0,0 +1,426 @@ +# DreamBooth training +# XXX dropped option: fine_tune + +import gc +import time +import argparse +import itertools +import math +import os +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer = train_util.load_tokenizer(args) + + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + + if args.no_token_padding: + train_dataset_group.disable_token_padding() + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + print("prepare accelerator") + + if args.gradient_accumulation_steps > 1: + print( + f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" + ) + print( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" + ) + + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # 学習を準備する:モデルを適切な状態にする + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 + unet.requires_grad_(True) # 念のため追加 + text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + print("Text Encoder is not trained.") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + if train_text_encoder: + trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + trainable_params = unet.parameters() + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth") + + loss_list = [] + loss_total = 0.0 + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch+1 + + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 + unet.train() + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: + text_encoder.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + # 指定したステップ数でText Encoderの学習を止める + if global_step == args.stop_text_encoder_training: + print(f"stop text encoder training at step {global_step}") + if not args.gradient_checkpointing: + text_encoder.train(False) + text_encoder.requires_grad_(False) + + with accelerator.accumulate(unet): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype + ) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + if train_text_encoder: + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False, True) + train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--no_token_padding", + action="store_true", + help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", + ) + parser.add_argument( + "--stop_text_encoder_training", + type=int, + default=None, + help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_db.py b/train_db.py index bd33731..b3eead9 100644 --- a/train_db.py +++ b/train_db.py @@ -8,9 +8,9 @@ import itertools import math import os import toml +from multiprocessing import Value from tqdm import tqdm -from library.custom_train_functions import apply_snr_weight import torch from accelerate.utils import set_seed import diffusers @@ -22,10 +22,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - - -def collate_fn(examples): - return examples[0] +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def train(args): @@ -60,6 +58,11 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + if args.no_token_padding: train_dataset_group.disable_token_padding() @@ -153,16 +156,21 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + if args.stop_text_encoder_training is None: args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end @@ -230,7 +238,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + current_epoch.value = epoch + 1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -239,6 +247,7 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") @@ -291,8 +300,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - - loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -393,6 +403,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--no_token_padding", diff --git a/train_network - Copy.py b/train_network - Copy.py new file mode 100644 index 0000000..20ad2c4 --- /dev/null +++ b/train_network - Copy.py @@ -0,0 +1,710 @@ +from torch.nn.parallel import DistributedDataParallel as DDP +import importlib +import argparse +import gc +import math +import os +import random +import time +import json +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +import library.train_util as train_util +from library.train_util import ( + DreamBoothDataset, +) +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = {"loss/current": current_loss, "loss/average": avr_loss} + + if args.network_train_unet_only: + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0]) + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + else: + logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0]) + logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder + + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet. + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + + # work on low-ram device + if args.lowram: + text_encoder.to("cuda") + unet.to("cuda") + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # prepare network + import sys + + sys.path.append(os.path.dirname(__file__)) + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + if network is None: + return + + if args.network_weights is not None: + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) + + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + if is_main_process: + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_unet and train_text_encoder: + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + elif train_text_encoder: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + text_encoder.train() + + # set top parameter requires_grad = True for gradient checkpointing works + if type(text_encoder) == DDP: + text_encoder.module.text_model.embeddings.requires_grad_(True) + else: + text_encoder.text_model.embeddings.requires_grad_(True) + else: + unet.eval() + text_encoder.eval() + + # support DistributedDataParallel + if type(text_encoder) == DDP: + text_encoder = text_encoder.module + unet = unet.module + network = network.module + + network.prepare_grad_etc(text_encoder, unet) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + if is_main_process: + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not use this value + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + } + + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } + + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + } + + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir + + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える + # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない + # なので、ここで複数datasetの回数を合算してもあまり意味はない + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。" + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) + + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + # for key, value in net_kwargs.items(): + # metadata["ss_arg_" + key] = value + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"] + minimum_metadata = {} + for key in minimum_keys: + if key in metadata: + minimum_metadata[key] = metadata[key] + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + if accelerator.is_main_process: + accelerator.init_trackers("network_train") + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + for epoch in range(num_train_epochs): + if is_main_process: + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch+1 + + metadata["ss_epoch"] = str(epoch + 1) + + network.on_epoch_start(text_encoder, unet) + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(network): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet + ) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + metadata["ss_training_finished_at"] = str(time.time()) + print(f"saving checkpoint: {ckpt_file}") + unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + if is_main_process: + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + metadata["ss_epoch"] = str(num_train_epochs) + metadata["ss_training_finished_at"] = str(time.time()) + + if is_main_process: + network = unwrap_model(network) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata) + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール") + parser.add_argument( + "--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)" + ) + parser.add_argument( + "--network_alpha", + type=float, + default=1, + help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)", + ) + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") + parser.add_argument( + "--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + ) + parser.add_argument( + "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_network.py b/train_network.py index ff990d9..423649e 100644 --- a/train_network.py +++ b/train_network.py @@ -8,6 +8,7 @@ import random import time import json import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -15,7 +16,6 @@ from accelerate.utils import set_seed from diffusers import DDPMScheduler import library.train_util as train_util -from library.custom_train_functions import apply_snr_weight from library.train_util import ( DreamBoothDataset, ) @@ -24,10 +24,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - - -def collate_fn(examples): - return examples[0] +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight # TODO 他のスクリプトと共通化する @@ -101,6 +99,11 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -186,11 +189,12 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -201,6 +205,9 @@ def train(args): if is_main_process: print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -489,22 +496,23 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) - if accelerator.is_main_process: accelerator.init_trackers("network_train") loss_list = [] loss_total = 0.0 + del train_dataset_group for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + current_epoch.value = epoch+1 metadata["ss_epoch"] = str(epoch + 1) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -529,7 +537,6 @@ def train(args): # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -549,8 +556,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - - loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -655,6 +663,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument( diff --git a/train_textual_inversion - Copy.py b/train_textual_inversion - Copy.py new file mode 100644 index 0000000..681bc62 --- /dev/null +++ b/train_textual_inversion - Copy.py @@ -0,0 +1,589 @@ +import importlib +import argparse +import gc +import math +import os +import toml +from multiprocessing import Value + +from tqdm import tqdm +import torch +from accelerate.utils import set_seed +import diffusers +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +def train(args): + if args.output_name is None: + args.output_name = args.token_string + use_template = args.use_object_template or args.use_style_template + + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) + + tokenizer = train_util.load_tokenizer(args) + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + + # Convert the init_word to token_id + if args.init_word is not None: + init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False) + if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token: + print( + f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}" + ) + else: + init_token_ids = None + + # add new word to tokenizer, count is num_vectors_per_token + token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] + num_added_tokens = tokenizer.add_tokens(token_strings) + assert ( + num_added_tokens == args.num_vectors_per_token + ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}" + + token_ids = tokenizer.convert_tokens_to_ids(token_strings) + print(f"tokens are added: {token_ids}") + assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered" + assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}" + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_token_ids is not None: + for i, token_id in enumerate(token_ids): + token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]] + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + + # load weights + if args.weights is not None: + embeddings = load_weights(args.weights) + assert len(token_ids) == len( + embeddings + ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}" + # print(token_ids, embeddings.size()) + for token_id, embedding in zip(token_ids, embeddings): + token_embeds[token_id] = embedding + # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) + print(f"weighs loaded") + + print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 + if use_template: + print("use template for training captions. is object: {args.use_object_template}") + templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small + replace_to = " ".join(token_strings) + captions = [] + for tmpl in templates: + captions.append(tmpl.format(replace_to)) + train_dataset_group.add_replacement("", captions) + + if args.num_vectors_per_token > 1: + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + else: + if args.num_vectors_per_token > 1: + replace_to = " ".join(token_strings) + train_dataset_group.add_replacement(args.token_string, replace_to) + prompt_replacement = (args.token_string, replace_to) + else: + prompt_replacement = None + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, show_input_ids=True) + return + if len(train_dataset_group) == 0: + print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents(vae, args.vae_batch_size) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + trainable_params = text_encoder.get_input_embeddings().parameters() + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # acceleratorがなんかよろしくやってくれるらしい + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] + # print(len(index_no_updates), torch.sum(index_no_updates)) + orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() + + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.requires_grad_(True) + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + else: + unet.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + text_encoder.to(weight_dtype) + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch+1 + + text_encoder.train() + + loss_total = 0 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(text_encoder): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + # weight_dtype) use float instead of fp16/bf16 because text encoder is float + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = text_encoder.get_input_embeddings().parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Let's make sure we don't update any embedding weights besides the newly added token + with torch.no_grad(): + unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ + index_no_updates + ] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} + if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value + logs["lr/d*lr"] = ( + lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] + ) + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() + + if args.save_every_n_epochs is not None: + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + ) + + # end of epoch + + is_main_process = accelerator.is_main_process + if is_main_process: + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + "." + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model to {ckpt_file}") + save_weights(ckpt_file, updated_embs, save_dtype) + print("model saved.") + + +def save_weights(file, updated_embs, save_dtype): + state_dict = {"emb_params": updated_embs} + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file) + else: + torch.save(state_dict, file) # can be loaded in Web UI + + +def load_weights(file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + data = load_file(file) + else: + # compatible to Web UI's file format + data = torch.load(file, map_location="cpu") + if type(data) != dict: + raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}") + + if "string_to_param" in data: # textual inversion embeddings + data = data["string_to_param"] + if hasattr(data, "_parameters"): # support old PyTorch? + data = getattr(data, "_parameters") + + emb = next(iter(data.values())) + if type(emb) != torch.Tensor: + raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}") + + if len(emb.size()) == 1: + emb = emb.unsqueeze(0) + + return emb + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="pt", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)", + ) + + parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み") + parser.add_argument( + "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数" + ) + parser.add_argument( + "--token_string", + type=str, + default=None, + help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること", + ) + parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可") + parser.add_argument( + "--use_object_template", + action="store_true", + help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する", + ) + parser.add_argument( + "--use_style_template", + action="store_true", + help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 85f0d57..f279370 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -4,6 +4,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -17,6 +18,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [ "a photo of a {}", @@ -71,10 +74,6 @@ imagenet_style_templates_small = [ ] -def collate_fn(examples): - return examples[0] - - def train(args): if args.output_name is None: args.output_name = args.token_string @@ -185,6 +184,11 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: print("use template for training captions. is object: {args.use_object_template}") @@ -250,7 +254,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -260,6 +264,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -331,12 +338,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + current_epoch.value = epoch+1 text_encoder.train() loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -377,6 +386,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights @@ -534,6 +546,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--save_model_as",