diff --git a/README.md b/README.md index b93ebb5..c179bbf 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,15 @@ my_asd_dog_dreambooth `- dog8.png ``` -## Execution +## GUI + +There is now support for GUI based training using gradio. You can start the GUI interface by running: + +```powershell +python .\dreambooth_gui.py +``` + +## Manual Script Execution ### SD1.5 example @@ -276,22 +284,21 @@ Refer to this url for more details about finetuning: https://note.com/kohya_ss/n ## Options list ```txt -usage: train_db_fixed.py [-h] [--v2] [--v_parameterization] - [--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH] [--fine_tuning] - [--shuffle_caption] [--caption_extention CAPTION_EXTENTION] +usage: train_db_fixed.py [-h] [--v2] [--v_parameterization] [--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH] + [--fine_tuning] [--shuffle_caption] [--caption_extention CAPTION_EXTENTION] [--caption_extension CAPTION_EXTENSION] [--train_data_dir TRAIN_DATA_DIR] - [--reg_data_dir REG_DATA_DIR] [--dataset_repeats DATASET_REPEATS] [--output_dir OUTPUT_DIR] - [--use_safetensors] [--save_every_n_epochs SAVE_EVERY_N_EPOCHS] [--save_state] [--resume RESUME] + [--reg_data_dir REG_DATA_DIR] [--dataset_repeats DATASET_REPEATS] [--output_dir OUTPUT_DIR] + [--use_safetensors] [--save_every_n_epochs SAVE_EVERY_N_EPOCHS] [--save_state] [--resume RESUME] [--prior_loss_weight PRIOR_LOSS_WEIGHT] [--no_token_padding] [--stop_text_encoder_training STOP_TEXT_ENCODER_TRAINING] [--color_aug] [--flip_aug] [--face_crop_aug_range FACE_CROP_AUG_RANGE] [--random_crop] [--debug_dataset] - [--resolution RESOLUTION] [--train_batch_size TRAIN_BATCH_SIZE] [--use_8bit_adam] - [--mem_eff_attn] [--xformers] [--vae VAE] [--cache_latents] [--enable_bucket] - [--min_bucket_reso MIN_BUCKET_RESO] [--max_bucket_reso MAX_BUCKET_RESO] - [--learning_rate LEARNING_RATE] [--max_train_steps MAX_TRAIN_STEPS] [--seed SEED] - [--gradient_checkpointing] [--mixed_precision {no,fp16,bf16}] - [--save_precision {None,float,fp16,bf16}] [--clip_skip CLIP_SKIP] [--logging_dir LOGGING_DIR] - [--log_prefix LOG_PREFIX] [--lr_scheduler LR_SCHEDULER] [--lr_warmup_steps LR_WARMUP_STEPS] + [--resolution RESOLUTION] [--train_batch_size TRAIN_BATCH_SIZE] [--use_8bit_adam] [--mem_eff_attn] + [--xformers] [--vae VAE] [--cache_latents] [--enable_bucket] [--min_bucket_reso MIN_BUCKET_RESO] + [--max_bucket_reso MAX_BUCKET_RESO] [--learning_rate LEARNING_RATE] + [--max_train_steps MAX_TRAIN_STEPS] [--seed SEED] [--gradient_checkpointing] + [--mixed_precision {no,fp16,bf16}] [--full_fp16] [--save_precision {None,float,fp16,bf16}] + [--clip_skip CLIP_SKIP] [--logging_dir LOGGING_DIR] [--log_prefix LOG_PREFIX] + [--lr_scheduler LR_SCHEDULER] [--lr_warmup_steps LR_WARMUP_STEPS] options: -h, --help show this help message and exit @@ -303,7 +310,7 @@ options: --fine_tuning fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする --shuffle_caption shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする --caption_extention CAPTION_EXTENTION - extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを 残してあります) + extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります) --caption_extension CAPTION_EXTENSION extension of caption files / 読み込むcaptionファイルの拡張子 --train_data_dir TRAIN_DATA_DIR @@ -314,10 +321,9 @@ options: repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数 --output_dir OUTPUT_DIR directory to output trained model / 学習後のモデル出力先ディレクトリ - --use_safetensors use safetensors 
format for StableDiffusion checkpoint / - StableDiffusionのcheckpointをsafetensors形式で保存する + --use_safetensors use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する --save_every_n_epochs SAVE_EVERY_N_EPOCHS - save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します + save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する --save_state save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する --resume RESUME saved state to resume training / 学習再開するモデルのstate @@ -333,17 +339,17 @@ options: enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0) --random_crop enable random crop (for style training in face-centered crop augmentation) / - ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する) - --debug_dataset show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない) + ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する) + --debug_dataset show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない) --resolution RESOLUTION - resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高 さ'指定) + resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ' 指定) --train_batch_size TRAIN_BATCH_SIZE batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習) - --use_8bit_adam use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要) - --mem_eff_attn use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う + --use_8bit_adam use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインス トールが必要) + --mem_eff_attn use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う --xformers use xformers for CrossAttention / CrossAttentionにxformersを使う - --vae VAE path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ + --vae VAE path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ --cache_latents cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可) --enable_bucket enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする @@ -360,19 +366,20 @@ options: enable gradient checkpointing / grandient checkpointingを有効にする --mixed_precision {no,fp16,bf16} use mixed precision / 混合精度を使う場合、その精度 + --full_fp16 fp16 training including gradients / 勾配も含めてfp16で学習する --save_precision {None,float,fp16,bf16} precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効) --clip_skip CLIP_SKIP - use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を 用いる(nは1以上) + use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上) --logging_dir LOGGING_DIR enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する --log_prefix LOG_PREFIX add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列 --lr_scheduler LR_SCHEDULER - scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, - polynomial, constant (default), constant_with_warmup + scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, + constant (default), constant_with_warmup --lr_warmup_steps LR_WARMUP_STEPS Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0) diff 
--git a/finetune_gui.py b/dreambooth_gui.py similarity index 92% rename from finetune_gui.py rename to dreambooth_gui.py index 1b938f7..344494f 100644 --- a/finetune_gui.py +++ b/dreambooth_gui.py @@ -34,7 +34,9 @@ def save_variables( cache_latent, caption_extention, use_safetensors, - enable_bucket + enable_bucket, + gradient_checkpointing, + full_fp16 ): # Return the values of the variables as a dictionary variables = { @@ -61,7 +63,9 @@ def save_variables( "cache_latent": cache_latent, "caption_extention": caption_extention, "use_safetensors": use_safetensors, - "enable_bucket": enable_bucket + "enable_bucket": enable_bucket, + "gradient_checkpointing": gradient_checkpointing, + "full_fp16": full_fp16 } # Save the data to the selected file @@ -100,6 +104,8 @@ def load_variables(file_path): my_data.get("caption_extention", None), my_data.get("use_safetensors", None), my_data.get("enable_bucket", None), + my_data.get("gradient_checkpointing", None), + my_data.get("full_fp16", None), ) @@ -127,7 +133,9 @@ def train_model( cache_latent, caption_extention, use_safetensors, - enable_bucket + enable_bucket, + gradient_checkpointing, + full_fp16 ): def save_inference_file(output_dir, v2, v_model): # Copy inference model for v2 if required @@ -189,6 +197,10 @@ def train_model( run_cmd += " --use_safetensors" if enable_bucket: run_cmd += " --enable_bucket" + if gradient_checkpointing: + run_cmd += " --gradient_checkpointing" + if full_fp16: + run_cmd += " --full_fp16" run_cmd += f" --pretrained_model_name_or_path={pretrained_model_name_or_path}" run_cmd += f" --train_data_dir={train_data_dir}" run_cmd += f" --reg_data_dir={reg_data_dir}" @@ -287,8 +299,8 @@ with interface: with gr.Row(): config_file_name = gr.Textbox( label="Config file name") - b1 = gr.Button("Load config") - b2 = gr.Button("Save config") + button_load_config = gr.Button("Load config") + button_save_config = gr.Button("Save config") with gr.Tab("Source model"): # Define the input elements with gr.Row(): @@ -399,6 +411,12 @@ with interface: cache_latent_input = gr.Checkbox( label="Cache latent", value=True ) + gradient_checkpointing_input = gr.Checkbox( + label="Gradient checkpointing", value=False + ) + full_fp16_input = gr.Checkbox( + label="Full fp16 training (experimental)", value=False + ) with gr.Tab("Model conversion"): convert_to_safetensors_input = gr.Checkbox( @@ -408,9 +426,9 @@ with interface: label="Convert to CKPT", value=False ) - b3 = gr.Button("Run") + button_run = gr.Button("Run") - b1.click( + button_load_config.click( load_variables, inputs=[config_file_name], outputs=[ @@ -437,11 +455,13 @@ with interface: cache_latent_input, caption_extention_input, use_safetensors_input, - enable_bucket_input + enable_bucket_input, + gradient_checkpointing_input, + full_fp16_input ] ) - b2.click( + button_save_config.click( save_variables, inputs=[ config_file_name, @@ -468,10 +488,12 @@ with interface: cache_latent_input, caption_extention_input, use_safetensors_input, - enable_bucket_input + enable_bucket_input, + gradient_checkpointing_input, + full_fp16_input ] ) - b3.click( + button_run.click( train_model, inputs=[ pretrained_model_name_or_path_input, @@ -497,7 +519,9 @@ with interface: cache_latent_input, caption_extention_input, use_safetensors_input, - enable_bucket_input + enable_bucket_input, + gradient_checkpointing_input, + full_fp16_input ] ) diff --git a/model_util.py b/model_util.py index 74650bf..9610c90 100644 --- a/model_util.py +++ b/model_util.py @@ -813,7 +813,7 @@ def 
convert_vae_state_dict(vae_state_dict): # endregion -# region 自作のモデル読み書き +# region 自作のモデル読み書きなど def is_safetensors(path): return os.path.splitext(path)[1].lower() == '.safetensors' @@ -1046,7 +1046,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p return key_count -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None): +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): if vae is None: vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") pipeline = StableDiffusionPipeline( @@ -1059,7 +1059,7 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod feature_extractor=None, requires_safety_checker=None, ) - pipeline.save_pretrained(output_dir) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) VAE_PREFIX = "first_stage_model." @@ -1117,6 +1117,7 @@ def get_epoch_ckpt_name(use_safetensors, epoch): def get_last_ckpt_name(use_safetensors): return f"last" + (".safetensors" if use_safetensors else ".ckpt") + # endregion diff --git a/requirements.txt b/requirements.txt index c2733e3..cfb2bdb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ -accelerate==0.14.0 +accelerate==0.15.0 transformers==4.25.1 ftfy albumentations opencv-python einops -diffusers[torch]==0.9.0 +diffusers[torch]==0.10.2 pytorch_lightning bitsandbytes==0.35.0 tensorboard -safetensors==0.2.5 +safetensors==0.2.6 gradio altair \ No newline at end of file diff --git a/tools/prune.py b/tools/prune.py index 199960b..6493bb3 100644 --- a/tools/prune.py +++ b/tools/prune.py @@ -1,4 +1,3 @@ -import os import argparse import torch from tqdm import tqdm @@ -23,7 +22,7 @@ del theta_prune if args.half: print("Halving model...") - state_dict = {k: v.half() for k, v in theta.items()} + state_dict = {k: v.half() for k, v in tqdm(theta.items(), desc="Halving weights")} else: state_dict = theta diff --git a/train_db_fixed.py b/train_db_fixed.py index c19b6dc..ff99d9f 100644 --- a/train_db_fixed.py +++ b/train_db_fixed.py @@ -13,6 +13,9 @@ # v12: stop train text encode, tqdm smoothing # v13: bug fix # v14: refactor to use model_util, add log prefix, support safetensors, support vae loading, keep vae in CPU to save the loaded vae +# v15: model_util update +# v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0 +# v17: add fp16 gradient training (experimental) import gc import time @@ -43,7 +46,7 @@ import model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ # CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336" @@ -596,10 +599,6 @@ def replace_unet_cross_attn_to_memory_efficient(): out = rearrange(out, 'b h n d -> b n (h d)') - # diffusers 0.6.0 - if type(self.to_out) is torch.nn.Sequential: - return self.to_out(out) - # diffusers 0.7.0~ out = self.to_out[0](out) out = self.to_out[1](out) @@ -633,10 +632,6 @@ def replace_unet_cross_attn_to_xformers(): out = rearrange(out, 'b n h d -> b n (h d)', h=h) # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - # diffusers 0.6.0 - if type(self.to_out) is torch.nn.Sequential: - return self.to_out(out) - # diffusers 0.7.0~ out = 
self.to_out[0](out) out = self.to_out[1](out) @@ -821,6 +816,19 @@ def train(args): accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir) + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + # mixed precisionに対応した型を用意しておき適宜castする weight_dtype = torch.float32 if args.mixed_precision == "fp16": @@ -914,6 +922,13 @@ def train(args): lr_scheduler = diffusers.optimization.get_scheduler( args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + # acceleratorがなんかよろしくやってくれるらしい unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler) @@ -921,6 +936,15 @@ def train(args): if not cache_latents: vae.to(accelerator.device, dtype=weight_dtype) + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + # resumeする if args.resume is not None: print(f"resume training from state: {args.resume}") @@ -946,8 +970,8 @@ def train(args): # v12で更新:clip_sample=Falseに # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0はすべてschdulerのconfigは(クラス名を除いて)同じ - # よくソースを見たら学習時は関係ないや(;'∀')  + # 既存の1.4/1.5/2.0/2.1はすべてschdulerのconfigは(クラス名を除いて)同じ + # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀')  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) @@ -1006,31 +1030,8 @@ def train(args): if args.v_parameterization: # v-parameterization training - # こうしたい: - # target = noise_scheduler.get_v(latents, noise, timesteps) - - # StabilityAiのddpm.pyのコード: - # elif self.parameterization == "v": - # target = self.get_v(x_start, noise, t) - # ... 
- # def get_v(self, x, noise, t): - # return ( - # extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - # extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x - # ) - - # scheduling_ddim.pyのコード: - # elif self.config.prediction_type == "v_prediction": - # pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - # # predict V - # model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - - # これでいいかな?: - alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps] - beta_prod_t = 1 - alpha_prod_t - alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # broadcastされないらしいのでreshape - beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1)) - target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents + # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -1081,13 +1082,14 @@ def train(args): if use_stable_diffusion_format: os.makedirs(args.output_dir, exist_ok=True) ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1)) - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet), + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype, vae) else: out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder), - accelerator.unwrap_model(unet), args.pretrained_model_name_or_path) + model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), + unwrap_model(unet), args.pretrained_model_name_or_path, + use_safetensors=args.use_safetensors) if args.save_state: print("saving state.") @@ -1095,8 +1097,8 @@ def train(args): is_main_process = accelerator.is_main_process if is_main_process: - unet = accelerator.unwrap_model(unet) - text_encoder = accelerator.unwrap_model(text_encoder) + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) accelerator.end_training() @@ -1118,7 +1120,8 @@ def train(args): print(f"save trained model as Diffusers to {args.output_dir}") out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path) + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path, + use_safetensors=args.use_safetensors) print("model saved.") @@ -1147,9 +1150,9 @@ if __name__ == '__main__': parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ") parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format for StableDiffusion checkpoint / StableDiffusionのcheckpointをsafetensors形式で保存する") + help="use safetensors format to save / checkpoint、モデルをsafetensors形式で保存する") parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します") + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_state", action="store_true", help="save training state 
additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -1191,6 +1194,7 @@ if __name__ == '__main__': help="enable gradient checkpointing / grandient checkpointingを有効にする") parser.add_argument("--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") parser.add_argument("--save_precision", type=str, default=None, choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") parser.add_argument("--clip_skip", type=int, default=None,
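
As a usage sketch for the options this change adds (`--full_fp16`, `--gradient_checkpointing`, and saving in safetensors format with `--use_safetensors`): the paths, accelerate launch settings, and hyperparameter values below are illustrative placeholders, and note that `--full_fp16` is experimental and requires `--mixed_precision="fp16"` (the script asserts this).

```powershell
accelerate launch --num_cpu_threads_per_process 4 train_db_fixed.py `
    --pretrained_model_name_or_path="D:\models\model.ckpt" `
    --train_data_dir="D:\dreambooth\train_dog" `
    --reg_data_dir="D:\dreambooth\reg_dog" `
    --output_dir="D:\dreambooth\output" `
    --prior_loss_weight=1.0 `
    --resolution=512 `
    --train_batch_size=1 `
    --learning_rate=1e-6 `
    --max_train_steps=1600 `
    --use_8bit_adam --xformers --cache_latents `
    --gradient_checkpointing `
    --mixed_precision="fp16" --full_fp16 `
    --use_safetensors
```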