diff --git a/README.md b/README.md
index 7deeebd..3ad9e58 100644
--- a/README.md
+++ b/README.md
@@ -390,4 +390,8 @@ options:
   - Fixed a bug where prior_loss_weight was applied to learning images. Sorry for the inconvenience.
   - Compatible with Stable Diffusion v2.0. Add the `--v2` option. If you are using `768-v-ema.ckpt` or `stable-diffusion-2` instead of `stable-diffusion-v2-base`, add `--v_parameterization` as well. Learn more about other options.
   - Added options related to the learning rate scheduler.
-  - You can download and use DiffUsers models directly from Hugging Face. In addition, DiffUsers models can be saved during training.
\ No newline at end of file
+  - You can download and use DiffUsers models directly from Hugging Face. In addition, DiffUsers models can be saved during training.
+* 11/29 (v12) update:
+  - stop training the text encoder at a specified step (`--stop_text_encoder_training=`)
+  - tqdm smoothing (see the note below)
+  - updated the fine tuning script to support SD2.0 768/v
\ No newline at end of file
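Note: the "tqdm smoothing" bullet refers to the `smoothing=0` argument added to the progress bar in `train_db_fixed.py` (see the hunk further down). In tqdm, `smoothing` blends the instantaneous and average iteration rates for the it/s estimate; `0` selects the plain average over the whole run, which gives a steadier rate and ETA readout when step times vary. A minimal standalone illustration, not part of the patch:

```python
# smoothing=0 averages the rate over the entire run instead of using an
# exponential moving average (tqdm's default is smoothing=0.3), so the
# ETA no longer jumps around when individual steps differ in duration.
from tqdm import tqdm
import time

for _ in tqdm(range(100), smoothing=0, desc="steps"):
    time.sleep(0.01)  # stand-in for one training step
```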
diff --git a/train_db_fixed.py b/train_db_fixed.py
index 25e39fa..ef6fc63 100644
--- a/train_db_fixed.py
+++ b/train_db_fixed.py
@@ -6,10 +6,11 @@
 # v8: supports Diffusers 0.7.2
 # v9: add bucketing option
 # v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
-# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
+# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
 #      add lr scheduler options, change handling of folder/file captions, support loading DiffUsers models from Hugging Face
 #      support save_every_n_epochs/save_state in DiffUsers model
 #      fix the issue that prior_loss_weight is applied to train images
+# v12: stop training the text encoder at a specified step, tqdm smoothing

 import time
 from torch.autograd.function import Function
@@ -39,33 +40,6 @@ from torch import einsum
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this model

-# model parameters of the DiffUsers version of Stable Diffusion
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32  # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# V2
-V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
-V2_UNET_PARAMS_CONTEXT_DIM = 1024
-
 # checkpoint file names
 LAST_CHECKPOINT_NAME = "last.ckpt"
 LAST_STATE_NAME = "last-state"
@@ -693,6 +667,34 @@ def replace_unet_cross_attn_to_xformers():
 # region checkpoint conversion, loading and saving
 ###############################

+# model parameters of the DiffUsers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 32  # unused
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+
+
 # region StableDiffusion->Diffusers conversion code
 # copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
@@ -1408,9 +1410,13 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
   return checkpoint


-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path):
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
   checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
   state_dict = checkpoint["state_dict"]
+  if dtype is not None:
+    for k, v in state_dict.items():
+      if type(v) is torch.Tensor:
+        state_dict[k] = v.to(dtype)

   # Convert the UNet2DConditionModel model.
   unet_config = create_unet_diffusers_config(v2)
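Note: the new `dtype` parameter above casts every tensor in the raw `state_dict` before the Diffusers conversion runs, so a checkpoint can, for example, be handled in fp16 to cut memory. A minimal standalone sketch of the same pattern; the dictionary contents here are made up for illustration:

```python
import torch

# hypothetical stand-in for torch.load(ckpt)["state_dict"]: checkpoints mix
# tensors with scalar bookkeeping entries, hence the type check.
state_dict = {"model.weight": torch.randn(4, 4), "global_step": 100}

dtype = torch.float16
for k, v in state_dict.items():
    if type(v) is torch.Tensor:
        state_dict[k] = v.to(dtype)  # reassigning existing keys is safe while iterating

print(state_dict["model.weight"].dtype)  # torch.float16
print(state_dict["global_step"])         # non-tensor entries are left untouched
```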
parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") + parser.add_argument("--stop_text_encoder_training", type=int, default=None, help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") parser.add_argument("--face_crop_aug_range", type=str, default=None,