From dc5afbb057c57f581f36f98d8c8f3cfac598bc53 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 9 Jan 2023 11:48:57 -0500 Subject: [PATCH] Move functions to common_gui Add model name support --- README.md | 3 + dreambooth_gui.py | 177 +---------- fine_tune.py | 681 +++++++++++++++++++++++++----------------- finetune_gui.py | 203 +++---------- library/common_gui.py | 80 ++++- lora_gui.py | 272 ++++------------- 6 files changed, 605 insertions(+), 811 deletions(-) diff --git a/README.md b/README.md index 5f5be6b..00afe58 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,9 @@ Once you have created the LoRA network you can generate images via auto1111 by i ## Change history +* 2023/01/10 (v20.1): + - Add support for `--output_name` to trainers + - Refactor code for easier maintenance * 2023/01/10 (v20.0): - Update code base to match latest kohys_ss code upgrade in https://github.com/kohya-ss/sd-scripts * 2023/01/09 (v19.4.3): diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 97eb609..f5a262f 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -18,6 +18,8 @@ from library.common_gui import ( get_any_file_path, get_saveasfile_path, color_aug_changed, + save_inference_file, + set_pretrained_model_name_or_path_input, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -102,45 +104,6 @@ def save_configuration( 'save_as', ] } - # variables = { - # 'pretrained_model_name_or_path': pretrained_model_name_or_path, - # 'v2': v2, - # 'v_parameterization': v_parameterization, - # 'logging_dir': logging_dir, - # 'train_data_dir': train_data_dir, - # 'reg_data_dir': reg_data_dir, - # 'output_dir': output_dir, - # 'max_resolution': max_resolution, - # 'learning_rate': learning_rate, - # 'lr_scheduler': lr_scheduler, - # 'lr_warmup': lr_warmup, - # 'train_batch_size': train_batch_size, - # 'epoch': epoch, - # 'save_every_n_epochs': save_every_n_epochs, - # 'mixed_precision': mixed_precision, - # 'save_precision': save_precision, - # 'seed': seed, - # 'num_cpu_threads_per_process': num_cpu_threads_per_process, - # 'cache_latent': cache_latent, - # 'caption_extention': caption_extention, - # 'enable_bucket': enable_bucket, - # 'gradient_checkpointing': gradient_checkpointing, - # 'full_fp16': full_fp16, - # 'no_token_padding': no_token_padding, - # 'stop_text_encoder_training': stop_text_encoder_training, - # 'use_8bit_adam': use_8bit_adam, - # 'xformers': xformers, - # 'save_model_as': save_model_as, - # 'shuffle_caption': shuffle_caption, - # 'save_state': save_state, - # 'resume': resume, - # 'prior_loss_weight': prior_loss_weight, - # 'color_aug': color_aug, - # 'flip_aug': flip_aug, - # 'clip_skip': clip_skip, - # 'vae': vae, - # 'output_name': output_name, - # } # Save the data to the selected file with open(file_path, 'w') as file: @@ -194,71 +157,24 @@ def open_configuration( original_file_path = file_path file_path = get_file_path(file_path) - # print(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: - my_data = json.load(f) + my_data_db = json.load(f) + print("Loading config...") else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data = {} + my_data_db = {} values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found if not key in ['file_path']: - values.append(my_data.get(key, value)) - # print(values) + values.append(my_data_db.get(key, value)) return tuple(values) - # Return the values of the variables as a dictionary - # return ( - # file_path, - # my_data.get( - # 'pretrained_model_name_or_path', pretrained_model_name_or_path - # ), - # my_data.get('v2', v2), - # my_data.get('v_parameterization', v_parameterization), - # my_data.get('logging_dir', logging_dir), - # my_data.get('train_data_dir', train_data_dir), - # my_data.get('reg_data_dir', reg_data_dir), - # my_data.get('output_dir', output_dir), - # my_data.get('max_resolution', max_resolution), - # my_data.get('learning_rate', learning_rate), - # my_data.get('lr_scheduler', lr_scheduler), - # my_data.get('lr_warmup', lr_warmup), - # my_data.get('train_batch_size', train_batch_size), - # my_data.get('epoch', epoch), - # my_data.get('save_every_n_epochs', save_every_n_epochs), - # my_data.get('mixed_precision', mixed_precision), - # my_data.get('save_precision', save_precision), - # my_data.get('seed', seed), - # my_data.get( - # 'num_cpu_threads_per_process', num_cpu_threads_per_process - # ), - # my_data.get('cache_latent', cache_latent), - # my_data.get('caption_extention', caption_extention), - # my_data.get('enable_bucket', enable_bucket), - # my_data.get('gradient_checkpointing', gradient_checkpointing), - # my_data.get('full_fp16', full_fp16), - # my_data.get('no_token_padding', no_token_padding), - # my_data.get('stop_text_encoder_training', stop_text_encoder_training), - # my_data.get('use_8bit_adam', use_8bit_adam), - # my_data.get('xformers', xformers), - # my_data.get('save_model_as', save_model_as), - # my_data.get('shuffle_caption', shuffle_caption), - # my_data.get('save_state', save_state), - # my_data.get('resume', resume), - # my_data.get('prior_loss_weight', prior_loss_weight), - # my_data.get('color_aug', color_aug), - # my_data.get('flip_aug', flip_aug), - # my_data.get('clip_skip', clip_skip), - # my_data.get('vae', vae), - # my_data.get('output_name', output_name), - # ) - - + def train_model( pretrained_model_name_or_path, v2, @@ -298,29 +214,6 @@ def train_model( vae, output_name, ): - def save_inference_file(output_dir, v2, v_parameterization, output_name): - # List all files in the directory - files = os.listdir(output_dir) - - # Iterate over the list of files - for file in files: - # Check if the file starts with the value of save_inference_file - if file.startswith(output_name): - # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension - if v2 and v_parameterization: - print(f'Saving v2-inference-v.yaml as {output_dir}/{file}.yaml') - shutil.copy( - f'./v2_inference/v2-inference-v.yaml', - f'{output_dir}/{file}.yaml', - ) - elif v2: - print(f'Saving v2-inference.yaml as {output_dir}/{file}.yaml') - shutil.copy( - f'./v2_inference/v2-inference.yaml', - f'{output_dir}/{file}.yaml', - ) - - if pretrained_model_name_or_path == '': msgbox('Source model information is missing') return @@ -487,57 +380,6 @@ def train_model( save_inference_file(output_dir, v2, v_parameterization, output_name) -def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): - # define a list of substrings to search for - substrings_v2 = [ - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - ] - - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list - if str(value) in substrings_v2: - print('SD v2 model detected. Setting --v2 parameter') - v2 = True - v_parameterization = False - - return value, v2, v_parameterization - - # define a list of substrings to search for v-objective - substrings_v_parameterization = [ - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - ] - - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list - if str(value) in substrings_v_parameterization: - print( - 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' - ) - v2 = True - v_parameterization = True - - return value, v2, v_parameterization - - # define a list of substrings to v1.x - substrings_v1_model = [ - 'CompVis/stable-diffusion-v1-4', - 'runwayml/stable-diffusion-v1-5', - ] - - if str(value) in substrings_v1_model: - v2 = False - v_parameterization = False - - return value, v2, v_parameterization - - if value == 'custom': - value = '' - v2 = False - v_parameterization = False - - return value, v2, v_parameterization - - def UI(username, password): css = '' @@ -593,11 +435,6 @@ def dreambooth_tab( placeholder="type the configuration file path or use the 'Open' button above to select it...", interactive=True, ) - # config_file_name.change( - # remove_doublequote, - # inputs=[config_file_name], - # outputs=[config_file_name], - # ) with gr.Tab('Source model'): # Define the input elements with gr.Row(): diff --git a/fine_tune.py b/fine_tune.py index 1a94870..d0ebd64 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -16,321 +16,456 @@ import library.train_util as train_util def collate_fn(examples): - return examples[0] + return examples[0] def train(args): - train_util.verify_training_args(args) - train_util.prepare_dataset_args(args, True) + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - cache_latents = args.cache_latents + cache_latents = args.cache_latents - if args.seed is not None: - set_seed(args.seed) # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - tokenizer = train_util.load_tokenizer(args) + tokenizer = train_util.load_tokenizer(args) - train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, - args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, - args.dataset_repeats, args.debug_dataset) - train_dataset.make_buckets() + train_dataset = train_util.FineTuningDataset( + args.in_json, + args.train_batch_size, + args.train_data_dir, + tokenizer, + args.max_token_length, + args.shuffle_caption, + args.keep_tokens, + args.resolution, + args.enable_bucket, + args.min_bucket_reso, + args.max_bucket_reso, + args.flip_aug, + args.color_aug, + args.face_crop_aug_range, + args.random_crop, + args.dataset_repeats, + args.debug_dataset, + ) + train_dataset.make_buckets() - if args.debug_dataset: - train_util.debug_dataset(train_dataset) - return - if len(train_dataset) == 0: - print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") - return + if args.debug_dataset: + train_util.debug_dataset(train_dataset) + return + if len(train_dataset) == 0: + print( + 'No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。' + ) + return - # acceleratorを準備する - print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + # acceleratorを準備する + print('prepare accelerator') + accelerator, unwrap_model = train_util.prepare_accelerator(args) - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype, save_dtype = train_util.prepare_dtype(args) + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) - # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + # モデルを読み込む + ( + text_encoder, + vae, + unet, + load_stable_diffusion_format, + ) = train_util.load_target_model(args, weight_dtype) - # verify load/save model formats - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None - else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path - - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' - use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - - # Diffusers版のxformers使用フラグを設定する関数 - def set_diffusers_xformers_flag(model, valid): - # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう - # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) - # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか - # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - - # Recursively walk through all the children. - # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - fn_recursive_set_mem_eff(model) - - # モデルに xformers とか memory efficient attention を組み込む - if args.diffusers_xformers: - print("Use xformers by Diffusers") - set_diffusers_xformers_flag(unet, True) - else: - # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある - print("Disable Diffusers' xformers") - set_diffusers_xformers_flag(unet, False) - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - # 学習を準備する - if cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) - vae.requires_grad_(False) - vae.eval() - with torch.no_grad(): - train_dataset.cache_latents(vae) - vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # 学習を準備する:モデルを適切な状態にする - training_models = [] - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - training_models.append(unet) - - if args.train_text_encoder: - print("enable text encoder training") - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - training_models.append(text_encoder) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) # text encoderは学習しない - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - text_encoder.train() # required for gradient_checkpointing + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None else: - text_encoder.eval() + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path - if not cache_latents: - vae.requires_grad_(False) - vae.eval() - vae.to(accelerator.device, dtype=weight_dtype) + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = ( + args.save_model_as.lower() == 'ckpt' + or args.save_model_as.lower() == 'safetensors' + ) + use_safetensors = args.use_safetensors or ( + 'safetensors' in args.save_model_as.lower() + ) - for m in training_models: - m.requires_grad_(True) - params = [] - for m in training_models: - params.extend(m.parameters()) - params_to_optimize = params + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう + # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) + # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか + # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - # 学習に必要なクラスを準備する - print("prepare optimizer, data loader etc.") + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, 'set_use_memory_efficient_attention_xformers'): + module.set_use_memory_efficient_attention_xformers(valid) - # 8-bit Adamを使う - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") - print("use 8-bit Adam optimizer") - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW + for child in module.children(): + fn_recursive_set_mem_eff(child) - # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate) + fn_recursive_set_mem_eff(model) - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 - train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + print('Use xformers by Diffusers') + set_diffusers_xformers_flag(unet, True) + else: + # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある + print("Disable Diffusers' xformers") + set_diffusers_xformers_flag(unet, False) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - # lr schedulerを用意する - lr_scheduler = diffusers.optimization.get_scheduler( - args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) - - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print("enable full fp16 training.") - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - # acceleratorがなんかよろしくやってくれるらしい - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - if args.resume is not None: - print(f"resume training from state: {args.resume}") - accelerator.load_state(args.resume) - - # epoch数を計算する - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # 学習する - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset.num_train_images}") - print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") - print(f" num epochs / epoch数: {num_train_epochs}") - print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") - print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") - - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") - global_step = 0 - - noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, clip_sample=False) - - if accelerator.is_main_process: - accelerator.init_trackers("finetuning") - - for epoch in range(num_train_epochs): - print(f"epoch {epoch+1}/{num_train_epochs}") - for m in training_models: - m.train() - - loss_total = 0 - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] + train_dataset.cache_latents(vae) + vae.to('cpu') + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() - with torch.set_grad_enabled(args.train_text_encoder): - # Get the text embedding for conditioning - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - - # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) + if args.train_text_encoder: + print('enable text encoder training') + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) + else: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) # text encoderは学習しない + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing else: - target = noise + text_encoder.eval() - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) - accelerator.backward(loss) - if accelerator.sync_gradients: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) + for m in training_models: + m.requires_grad_(True) + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + # 学習に必要なクラスを準備する + print('prepare optimizer, data loader etc.') - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + # 8-bit Adamを使う + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + 'No bitsand bytes / bitsandbytesがインストールされていないようです' + ) + print('use 8-bit Adam optimizer') + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW - current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} - accelerator.log(logs, step=global_step) + # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 + optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate) - loss_total += current_loss - avr_loss = loss_total / (step+1) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=1, + shuffle=False, + collate_fn=collate_fn, + num_workers=n_workers, + ) - if global_step >= args.max_train_steps: - break + # lr schedulerを用意する + lr_scheduler = diffusers.optimization.get_scheduler( + args.lr_scheduler, + optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps + * args.gradient_accumulation_steps, + ) - if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch+1) + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == 'fp16' + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print('enable full fp16 training.') + unet.to(weight_dtype) + text_encoder.to(weight_dtype) - accelerator.wait_for_everyone() + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + ( + unet, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) - if args.save_every_n_epochs is not None: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) - is_main_process = accelerator.is_main_process - if is_main_process: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) + # resumeする + if args.resume is not None: + print(f'resume training from state: {args.resume}') + accelerator.load_state(args.resume) - accelerator.end_training() + # epoch数を計算する + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + num_train_epochs = math.ceil( + args.max_train_steps / num_update_steps_per_epoch + ) - if args.save_state: - train_util.save_state_on_train_end(args, accelerator) + # 学習する + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) + print('running training / 学習開始') + print(f' num examples / サンプル数: {train_dataset.num_train_images}') + print(f' num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}') + print(f' num epochs / epoch数: {num_train_epochs}') + print(f' batch size per device / バッチサイズ: {args.train_batch_size}') + print( + f' total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}' + ) + print( + f' gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}' + ) + print(f' total optimization steps / 学習ステップ数: {args.max_train_steps}') - del accelerator # この後メモリを使うのでこれは消す + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc='steps', + ) + global_step = 0 - if is_main_process: - src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path - train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, - save_dtype, epoch, global_step, text_encoder, unet, vae) - print("model saved.") + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule='scaled_linear', + num_train_timesteps=1000, + clip_sample=False, + ) + + if accelerator.is_main_process: + accelerator.init_trackers('finetuning') + + for epoch in range(num_train_epochs): + print(f'epoch {epoch+1}/{num_train_epochs}') + for m in training_models: + m.train() + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate( + training_models[0] + ): # 複数モデルに対応していない模様だがとりあえずこうしておく + with torch.no_grad(): + if 'latents' in batch and batch['latents'] is not None: + latents = batch['latents'].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode( + batch['images'].to(dtype=weight_dtype) + ).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch['input_ids'].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, + input_ids, + tokenizer, + text_encoder, + None if not args.full_fp16 else weight_dtype, + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (b_size,), + device=latents.device, + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise( + latents, noise, timesteps + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, timesteps, encoder_hidden_states + ).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity( + latents, noise, timesteps + ) + else: + target = noise + + loss = torch.nn.functional.mse_loss( + noise_pred.float(), target.float(), reduction='mean' + ) + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_( + params_to_clip, 1.0 + ) # args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = { + 'loss': current_loss, + 'lr': lr_scheduler.get_last_lr()[0], + } + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step + 1) + logs = {'loss': avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {'epoch_loss': loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + src_path = ( + src_stable_diffusion_ckpt + if save_stable_diffusion_format + else src_diffusers_model_path + ) + train_util.save_sd_model_on_epoch_end( + args, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + unwrap_model(text_encoder), + unwrap_model(unet), + vae, + ) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) + + accelerator.end_training() + + if args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = ( + src_stable_diffusion_ckpt + if save_stable_diffusion_format + else src_diffusers_model_path + ) + train_util.save_sd_model_on_train_end( + args, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + global_step, + text_encoder, + unet, + vae, + ) + print('model saved.') if __name__ == '__main__': - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() - train_util.add_sd_models_arguments(parser) - train_util.add_dataset_arguments(parser, False, True) - train_util.add_training_arguments(parser, False) - train_util.add_sd_saving_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True) + train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) - parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers / Diffusersでxformersを使用する') - parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + '--diffusers_xformers', + action='store_true', + help='use xformers by diffusers / Diffusersでxformersを使用する', + ) + parser.add_argument( + '--train_text_encoder', + action='store_true', + help='train text encoder / text encoderも学習する', + ) - args = parser.parse_args() - train(args) + args = parser.parse_args() + train(args) diff --git a/finetune_gui.py b/finetune_gui.py index cd1f1ff..168b87e 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -11,6 +11,8 @@ from library.common_gui import ( get_file_path, get_any_file_path, get_saveasfile_path, + save_inference_file, + set_pretrained_model_name_or_path_input, ) from library.utilities import utilities_tab @@ -63,7 +65,11 @@ def save_configuration( gradient_accumulation_steps, mem_eff_attn, shuffle_caption, + output_name, ): + # Get list of function parameters and values + parameters = list(locals().items()) + original_file_path = file_path save_as_bool = True if save_as.get('label') == 'True' else False @@ -83,51 +89,18 @@ def save_configuration( # Return the values of the variables as a dictionary variables = { - 'pretrained_model_name_or_path': pretrained_model_name_or_path, - 'v2': v2, - 'v_parameterization': v_parameterization, - 'train_dir': train_dir, - 'image_folder': image_folder, - 'output_dir': output_dir, - 'logging_dir': logging_dir, - 'max_resolution': max_resolution, - 'min_bucket_reso': min_bucket_reso, - 'max_bucket_reso': max_bucket_reso, - 'batch_size': batch_size, - 'flip_aug': flip_aug, - 'caption_metadata_filename': caption_metadata_filename, - 'latent_metadata_filename': latent_metadata_filename, - 'full_path': full_path, - 'learning_rate': learning_rate, - 'lr_scheduler': lr_scheduler, - 'lr_warmup': lr_warmup, - 'dataset_repeats': dataset_repeats, - 'train_batch_size': train_batch_size, - 'epoch': epoch, - 'save_every_n_epochs': save_every_n_epochs, - 'mixed_precision': mixed_precision, - 'save_precision': save_precision, - 'seed': seed, - 'num_cpu_threads_per_process': num_cpu_threads_per_process, - 'train_text_encoder': train_text_encoder, - 'create_buckets': create_buckets, - 'create_caption': create_caption, - 'save_model_as': save_model_as, - 'caption_extension': caption_extension, - 'use_8bit_adam': use_8bit_adam, - 'xformers': xformers, - 'clip_skip': clip_skip, - 'save_state': save_state, - 'resume': resume, - 'gradient_checkpointing': gradient_checkpointing, - 'gradient_accumulation_steps': gradient_accumulation_steps, - 'mem_eff_attn': mem_eff_attn, - 'shuffle_caption': shuffle_caption, + name: value + for name, value in parameters # locals().items() + if name + not in [ + 'file_path', + 'save_as', + ] } # Save the data to the selected file with open(file_path, 'w') as file: - json.dump(variables, file) + json.dump(variables, file, indent=2) return file_path @@ -174,7 +147,11 @@ def open_config_file( gradient_accumulation_steps, mem_eff_attn, shuffle_caption, + output_name, ): + # Get list of function parameters and values + parameters = list(locals().items()) + original_file_path = file_path file_path = get_file_path(file_path) @@ -182,59 +159,18 @@ def open_config_file( print(f'Loading config file {file_path}') # load variables from JSON file with open(file_path, 'r') as f: - my_data = json.load(f) + my_data_ft = json.load(f) else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data = {} - - # Return the values of the variables as a dictionary - return ( - file_path, - my_data.get( - 'pretrained_model_name_or_path', pretrained_model_name_or_path - ), - my_data.get('v2', v2), - my_data.get('v_parameterization', v_parameterization), - my_data.get('train_dir', train_dir), - my_data.get('image_folder', image_folder), - my_data.get('output_dir', output_dir), - my_data.get('logging_dir', logging_dir), - my_data.get('max_resolution', max_resolution), - my_data.get('min_bucket_reso', min_bucket_reso), - my_data.get('max_bucket_reso', max_bucket_reso), - my_data.get('batch_size', batch_size), - my_data.get('flip_aug', flip_aug), - my_data.get('caption_metadata_filename', caption_metadata_filename), - my_data.get('latent_metadata_filename', latent_metadata_filename), - my_data.get('full_path', full_path), - my_data.get('learning_rate', learning_rate), - my_data.get('lr_scheduler', lr_scheduler), - my_data.get('lr_warmup', lr_warmup), - my_data.get('dataset_repeats', dataset_repeats), - my_data.get('train_batch_size', train_batch_size), - my_data.get('epoch', epoch), - my_data.get('save_every_n_epochs', save_every_n_epochs), - my_data.get('mixed_precision', mixed_precision), - my_data.get('save_precision', save_precision), - my_data.get('seed', seed), - my_data.get( - 'num_cpu_threads_per_process', num_cpu_threads_per_process - ), - my_data.get('train_text_encoder', train_text_encoder), - my_data.get('create_buckets', create_buckets), - my_data.get('create_caption', create_caption), - my_data.get('save_model_as', save_model_as), - my_data.get('caption_extension', caption_extension), - my_data.get('use_8bit_adam', use_8bit_adam), - my_data.get('xformers', xformers), - my_data.get('clip_skip', clip_skip), - my_data.get('save_state', save_state), - my_data.get('resume', resume), - my_data.get('gradient_checkpointing', gradient_checkpointing), - my_data.get('gradient_accumulation_steps', gradient_accumulation_steps), - my_data.get('mem_eff_attn', mem_eff_attn), - my_data.get('shuffle_caption', shuffle_caption), - ) + my_data_ft = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data_ft`, or the default value if not found + if not key in ['file_path']: + values.append(my_data_ft.get(key, value)) + # print(values) + return tuple(values) def train_model( @@ -278,22 +214,8 @@ def train_model( gradient_accumulation_steps, mem_eff_attn, shuffle_caption, + output_name, ): - def save_inference_file(output_dir, v2, v_parameterization): - # Copy inference model for v2 if required - if v2 and v_parameterization: - print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml') - shutil.copy( - f'./v2_inference/v2-inference-v.yaml', - f'{output_dir}/last.yaml', - ) - elif v2: - print(f'Saving v2-inference.yaml as {output_dir}/last.yaml') - shutil.copy( - f'./v2_inference/v2-inference.yaml', - f'{output_dir}/last.yaml', - ) - # create caption json file if generate_caption_database: if not os.path.exists(train_dir): @@ -407,68 +329,19 @@ def train_model( run_cmd += ' --save_state' if not resume == '': run_cmd += f' --resume={resume}' + if not output_name == '': + run_cmd += f' --output_name="{output_name}"' print(run_cmd) # Run the command subprocess.run(run_cmd) # check if output_dir/last is a folder... therefore it is a diffuser model - last_dir = pathlib.Path(f'{output_dir}/last') + last_dir = pathlib.Path(f'{output_dir}/{output_name}') if not last_dir.is_dir(): # Copy inference model for v2 if required - save_inference_file(output_dir, v2, v_parameterization) - - -def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): - # define a list of substrings to search for - substrings_v2 = [ - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - ] - - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list - if str(value) in substrings_v2: - print('SD v2 model detected. Setting --v2 parameter') - v2 = True - v_parameterization = False - - return value, v2, v_parameterization - - # define a list of substrings to search for v-objective - substrings_v_parameterization = [ - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - ] - - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list - if str(value) in substrings_v_parameterization: - print( - 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' - ) - v2 = True - v_parameterization = True - - return value, v2, v_parameterization - - # define a list of substrings to v1.x - substrings_v1_model = [ - 'CompVis/stable-diffusion-v1-4', - 'runwayml/stable-diffusion-v1-5', - ] - - if str(value) in substrings_v1_model: - v2 = False - v_parameterization = False - - return value, v2, v_parameterization - - if value == 'custom': - value = '' - v2 = False - v_parameterization = False - - return value, v2, v_parameterization + save_inference_file(output_dir, v2, v_parameterization, output_name) def remove_doublequote(file_path): @@ -610,7 +483,7 @@ def finetune_tab(): ) with gr.Row(): output_dir_input = gr.Textbox( - label='Output folder', + label='Model output folder', placeholder='folder where the model will be saved', ) output_dir_input_folder = gr.Button( @@ -630,6 +503,13 @@ def finetune_tab(): logging_dir_input_folder.click( get_folder_path, outputs=logging_dir_input ) + with gr.Row(): + output_name = gr.Textbox( + label='Model output name', + placeholder='Name of the model to output', + value='last', + interactive=True, + ) train_dir_input.change( remove_doublequote, inputs=[train_dir_input], @@ -814,6 +694,7 @@ def finetune_tab(): gradient_accumulation_steps, mem_eff_attn, shuffle_caption, + output_name, ] button_run.click(train_model, inputs=settings_list) diff --git a/library/common_gui.py b/library/common_gui.py index ff7b2c5..76c3a13 100644 --- a/library/common_gui.py +++ b/library/common_gui.py @@ -2,6 +2,7 @@ from tkinter import filedialog, Tk import os import gradio as gr from easygui import msgbox +import shutil def get_dir_and_file(file_path): dir_path, file_name = os.path.split(file_path) @@ -183,4 +184,81 @@ def color_aug_changed(color_aug): msgbox('Disabling "Cache latent" because "Color augmentation" has been selected...') return gr.Checkbox.update(value=False, interactive=False) else: - return gr.Checkbox.update(value=True, interactive=True) \ No newline at end of file + return gr.Checkbox.update(value=True, interactive=True) + +def save_inference_file(output_dir, v2, v_parameterization, output_name): + # List all files in the directory + files = os.listdir(output_dir) + + # Iterate over the list of files + for file in files: + # Check if the file starts with the value of output_name + if file.startswith(output_name): + # Check if it is a file or a directory + if os.path.isfile(os.path.join(output_dir, file)): + # Split the file name and extension + file_name, ext = os.path.splitext(file) + + # Copy the v2-inference-v.yaml file to the current file, with a .yaml extension + if v2 and v_parameterization: + print(f'Saving v2-inference-v.yaml as {output_dir}/{file_name}.yaml') + shutil.copy( + f'./v2_inference/v2-inference-v.yaml', + f'{output_dir}/{file_name}.yaml', + ) + elif v2: + print(f'Saving v2-inference.yaml as {output_dir}/{file_name}.yaml') + shutil.copy( + f'./v2_inference/v2-inference.yaml', + f'{output_dir}/{file_name}.yaml', + ) + +def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): + # define a list of substrings to search for + substrings_v2 = [ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + ] + + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list + if str(value) in substrings_v2: + print('SD v2 model detected. Setting --v2 parameter') + v2 = True + v_parameterization = False + + return value, v2, v_parameterization + + # define a list of substrings to search for v-objective + substrings_v_parameterization = [ + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + ] + + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list + if str(value) in substrings_v_parameterization: + print( + 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' + ) + v2 = True + v_parameterization = True + + return value, v2, v_parameterization + + # define a list of substrings to v1.x + substrings_v1_model = [ + 'CompVis/stable-diffusion-v1-4', + 'runwayml/stable-diffusion-v1-5', + ] + + if str(value) in substrings_v1_model: + v2 = False + v_parameterization = False + + return value, v2, v_parameterization + + if value == 'custom': + value = '' + v2 = False + v_parameterization = False + + return value, v2, v_parameterization \ No newline at end of file diff --git a/lora_gui.py b/lora_gui.py index aa3e52d..068b926 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -18,6 +18,8 @@ from library.common_gui import ( get_any_file_path, get_saveasfile_path, color_aug_changed, + save_inference_file, + set_pretrained_model_name_or_path_input, ) from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -76,8 +78,11 @@ def save_configuration( clip_skip, gradient_accumulation_steps, mem_eff_attn, - # vae, + output_name, ): + # Get list of function parameters and values + parameters = list(locals().items()) + original_file_path = file_path save_as_bool = True if save_as.get('label') == 'True' else False @@ -97,85 +102,51 @@ def save_configuration( # Return the values of the variables as a dictionary variables = { - 'pretrained_model_name_or_path': pretrained_model_name_or_path, - 'v2': v2, - 'v_parameterization': v_parameterization, - 'logging_dir': logging_dir, - 'train_data_dir': train_data_dir, - 'reg_data_dir': reg_data_dir, - 'output_dir': output_dir, - 'max_resolution': max_resolution, - 'lr_scheduler': lr_scheduler, - 'lr_warmup': lr_warmup, - 'train_batch_size': train_batch_size, - 'epoch': epoch, - 'save_every_n_epochs': save_every_n_epochs, - 'mixed_precision': mixed_precision, - 'save_precision': save_precision, - 'seed': seed, - 'num_cpu_threads_per_process': num_cpu_threads_per_process, - 'cache_latent': cache_latent, - 'caption_extention': caption_extention, - 'enable_bucket': enable_bucket, - 'gradient_checkpointing': gradient_checkpointing, - 'full_fp16': full_fp16, - 'no_token_padding': no_token_padding, - 'stop_text_encoder_training': stop_text_encoder_training, - 'use_8bit_adam': use_8bit_adam, - 'xformers': xformers, - 'save_model_as': save_model_as, - 'shuffle_caption': shuffle_caption, - 'save_state': save_state, - 'resume': resume, - 'prior_loss_weight': prior_loss_weight, - 'text_encoder_lr': text_encoder_lr, - 'unet_lr': unet_lr, - 'network_dim': network_dim, - 'lora_network_weights': lora_network_weights, - 'color_aug': color_aug, - 'flip_aug': flip_aug, - 'clip_skip': clip_skip, - 'gradient_accumulation_steps': gradient_accumulation_steps, - 'mem_eff_attn': mem_eff_attn, - # 'vae': vae, + name: value + for name, value in parameters # locals().items() + if name + not in [ + 'file_path', + 'save_as', + ] } # Save the data to the selected file with open(file_path, 'w') as file: - json.dump(variables, file) + json.dump(variables, file, indent=2) return file_path def open_configuration( file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latent, - caption_extention, - enable_bucket, + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + logging_dir_input, + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + max_resolution_input, + lr_scheduler_input, + lr_warmup_input, + train_batch_size_input, + epoch_input, + save_every_n_epochs_input, + mixed_precision_input, + save_precision_input, + seed_input, + num_cpu_threads_per_process_input, + cache_latent_input, + caption_extention_input, + enable_bucket_input, gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - use_8bit_adam, - xformers, - save_model_as, + full_fp16_input, + no_token_padding_input, + stop_text_encoder_training_input, + use_8bit_adam_input, + xformers_input, + save_model_as_dropdown, shuffle_caption, save_state, resume, @@ -189,70 +160,29 @@ def open_configuration( clip_skip, gradient_accumulation_steps, mem_eff_attn, - # vae, + output_name, ): + # Get list of function parameters and values + parameters = list(locals().items()) original_file_path = file_path file_path = get_file_path(file_path) - # print(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: - my_data = json.load(f) + my_data_lora = json.load(f) + print("Loading config...") else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data = {} - - # Return the values of the variables as a dictionary - return ( - file_path, - my_data.get( - 'pretrained_model_name_or_path', pretrained_model_name_or_path - ), - my_data.get('v2', v2), - my_data.get('v_parameterization', v_parameterization), - my_data.get('logging_dir', logging_dir), - my_data.get('train_data_dir', train_data_dir), - my_data.get('reg_data_dir', reg_data_dir), - my_data.get('output_dir', output_dir), - my_data.get('max_resolution', max_resolution), - my_data.get('lr_scheduler', lr_scheduler), - my_data.get('lr_warmup', lr_warmup), - my_data.get('train_batch_size', train_batch_size), - my_data.get('epoch', epoch), - my_data.get('save_every_n_epochs', save_every_n_epochs), - my_data.get('mixed_precision', mixed_precision), - my_data.get('save_precision', save_precision), - my_data.get('seed', seed), - my_data.get( - 'num_cpu_threads_per_process', num_cpu_threads_per_process - ), - my_data.get('cache_latent', cache_latent), - my_data.get('caption_extention', caption_extention), - my_data.get('enable_bucket', enable_bucket), - my_data.get('gradient_checkpointing', gradient_checkpointing), - my_data.get('full_fp16', full_fp16), - my_data.get('no_token_padding', no_token_padding), - my_data.get('stop_text_encoder_training', stop_text_encoder_training), - my_data.get('use_8bit_adam', use_8bit_adam), - my_data.get('xformers', xformers), - my_data.get('save_model_as', save_model_as), - my_data.get('shuffle_caption', shuffle_caption), - my_data.get('save_state', save_state), - my_data.get('resume', resume), - my_data.get('prior_loss_weight', prior_loss_weight), - my_data.get('text_encoder_lr', text_encoder_lr), - my_data.get('unet_lr', unet_lr), - my_data.get('network_dim', network_dim), - my_data.get('lora_network_weights', lora_network_weights), - my_data.get('color_aug', color_aug), - my_data.get('flip_aug', flip_aug), - my_data.get('clip_skip', clip_skip), - my_data.get('gradient_accumulation_steps', gradient_accumulation_steps), - my_data.get('mem_eff_attn', mem_eff_attn), - # my_data.get('vae', vae), - ) + my_data_lora = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ['file_path']: + values.append(my_data_lora.get(key, value)) + return tuple(values) def train_model( @@ -296,23 +226,8 @@ def train_model( clip_skip, gradient_accumulation_steps, mem_eff_attn, - # vae, + output_name, ): - def save_inference_file(output_dir, v2, v_parameterization): - # Copy inference model for v2 if required - if v2 and v_parameterization: - print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml') - shutil.copy( - f'./v2_inference/v2-inference-v.yaml', - f'{output_dir}/last.yaml', - ) - elif v2: - print(f'Saving v2-inference.yaml as {output_dir}/last.yaml') - shutil.copy( - f'./v2_inference/v2-inference.yaml', - f'{output_dir}/last.yaml', - ) - if pretrained_model_name_or_path == '': msgbox('Source model information is missing') return @@ -379,17 +294,6 @@ def train_model( # Print the result print(f'Folder {folder}: {steps} steps') - # Print the result - # print(f"{total_steps} total steps") - - # if reg_data_dir == '': - # reg_factor = 1 - # else: - # print( - # 'Regularisation images are used... Will double the number of steps required...' - # ) - # reg_factor = 2 - # calculate max_train_steps max_train_steps = int( math.ceil( @@ -496,68 +400,19 @@ def train_model( run_cmd += f' --gradient_accumulation_steps={int(gradient_accumulation_steps)}' # if not vae == '': # run_cmd += f' --vae="{vae}"' + if not output_name == '': + run_cmd += f' --output_name="{output_name}"' print(run_cmd) # Run the command subprocess.run(run_cmd) # check if output_dir/last is a folder... therefore it is a diffuser model - last_dir = pathlib.Path(f'{output_dir}/last') + last_dir = pathlib.Path(f'{output_dir}/{output_name}') if not last_dir.is_dir(): # Copy inference model for v2 if required - save_inference_file(output_dir, v2, v_parameterization) - - -def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): - # define a list of substrings to search for - substrings_v2 = [ - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - ] - - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list - if str(value) in substrings_v2: - print('SD v2 model detected. Setting --v2 parameter') - v2 = True - v_parameterization = False - - return value, v2, v_parameterization - - # define a list of substrings to search for v-objective - substrings_v_parameterization = [ - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - ] - - # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list - if str(value) in substrings_v_parameterization: - print( - 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' - ) - v2 = True - v_parameterization = True - - return value, v2, v_parameterization - - # define a list of substrings to v1.x - substrings_v1_model = [ - 'CompVis/stable-diffusion-v1-4', - 'runwayml/stable-diffusion-v1-5', - ] - - if str(value) in substrings_v1_model: - v2 = False - v_parameterization = False - - return value, v2, v_parameterization - - if value == 'custom': - value = '' - v2 = False - v_parameterization = False - - return value, v2, v_parameterization + save_inference_file(output_dir, v2, v_parameterization, output_name) def UI(username, password): @@ -731,6 +586,13 @@ def lora_tab( logging_dir_input_folder.click( get_folder_path, outputs=logging_dir_input ) + with gr.Row(): + output_name = gr.Textbox( + label='Model output name', + placeholder='Name of the model to output', + value='last', + interactive=True, + ) train_data_dir_input.change( remove_doublequote, inputs=[train_data_dir_input], @@ -766,7 +628,6 @@ def lora_tab( outputs=lora_network_weights, ) with gr.Row(): - # learning_rate_input = gr.Textbox(label='Learning rate', value=1e-4, visible=False) lr_scheduler_input = gr.Dropdown( label='LR Scheduler', choices=[ @@ -941,7 +802,6 @@ def lora_tab( reg_data_dir_input, output_dir_input, max_resolution_input, - # learning_rate_input, lr_scheduler_input, lr_warmup_input, train_batch_size_input, @@ -974,7 +834,7 @@ def lora_tab( clip_skip, gradient_accumulation_steps, mem_eff_attn, - # vae, + output_name, ] button_open_config.click(