diff --git a/finetune/prepare_buckets_latents_new.py b/finetune/prepare_buckets_latents_new.py
new file mode 100644
index 0000000..d68ea37
--- /dev/null
+++ b/finetune/prepare_buckets_latents_new.py
@@ -0,0 +1,180 @@
+# このスクリプトのライセンスは、Apache License 2.0とします
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import glob
+import os
+import json
+
+from tqdm import tqdm
+import numpy as np
+from diffusers import AutoencoderKL
+from PIL import Image
+import cv2
+import torch
+from torchvision import transforms
+
+import library.model_util as model_util
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+IMAGE_TRANSFORMS = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+
+def get_latents(vae, images, weight_dtype):
+    img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
+    img_tensors = torch.stack(img_tensors)
+    img_tensors = img_tensors.to(DEVICE, weight_dtype)
+    with torch.no_grad():
+        latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
+    return latents
+
+
+def main(args):
+    # image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
+    #     glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+    # print(f"found {len(image_paths)} images.")
+
+
+    if os.path.exists(args.in_json):
+        print(f"loading existing metadata: {args.in_json}")
+        with open(args.in_json, "rt", encoding='utf-8') as f:
+            metadata = json.load(f)
+    else:
+        print(f"no metadata / メタデータファイルがありません: {args.in_json}")
+        return
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
+    vae.eval()
+    vae.to(DEVICE, dtype=weight_dtype)
+
+    # bucketのサイズを計算する
+    max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
+    assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
+
+    bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
+        max_reso, args.min_bucket_reso, args.max_bucket_reso)
+
+    # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
+    bucket_aspect_ratios = np.array(bucket_aspect_ratios)
+    buckets_imgs = [[] for _ in range(len(bucket_resos))]
+    bucket_counts = [0 for _ in range(len(bucket_resos))]
+    img_ar_errors = []
+    for i, image_path in enumerate(tqdm(metadata, smoothing=0.0)):
+        image_key = image_path
+        if image_key not in metadata:
+            metadata[image_key] = {}
+
+        image = Image.open(image_path)
+        if image.mode != 'RGB':
+            image = image.convert("RGB")
+
+        aspect_ratio = image.width / image.height
+        ar_errors = bucket_aspect_ratios - aspect_ratio
+        bucket_id = np.abs(ar_errors).argmin()
+        reso = bucket_resos[bucket_id]
+        ar_error = ar_errors[bucket_id]
+        img_ar_errors.append(abs(ar_error))
+
+        # どのサイズにリサイズするか→トリミングする方向で
+        if ar_error <= 0:  # 横が長い→縦を合わせる
+            scale = reso[1] / image.height
+        else:
+            scale = reso[0] / image.width
+
+        resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
+
+        # print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
+        #       bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
+
+        assert resized_size[0] == reso[0] or resized_size[1] == reso[
+            1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
+        assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
+            1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
+
+        # 画像をリサイズしてトリミングする
+        # PILにinter_areaがないのでcv2で……
+        image = np.array(image)
+        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
+        if resized_size[0] > reso[0]:
+            trim_size = resized_size[0] - reso[0]
+            image = image[:, trim_size//2:trim_size//2 + reso[0]]
+        elif resized_size[1] > reso[1]:
+            trim_size = resized_size[1] - reso[1]
+            image = image[trim_size//2:trim_size//2 + reso[1]]
+        assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
+
+        # # debug
+        # cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
+
+        # バッチへ追加
+        buckets_imgs[bucket_id].append((image_key, reso, image))
+        bucket_counts[bucket_id] += 1
+        metadata[image_key]['train_resolution'] = reso
+
+        # バッチを推論するか判定して推論する
+        is_last = i == len(metadata) - 1
+        for j in range(len(buckets_imgs)):
+            bucket = buckets_imgs[j]
+            if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
+                latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
+
+                for (image_key, reso, _), latent in zip(bucket, latents):
+                    npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
+                    np.savez(os.path.join(os.path.dirname(image_key), npz_file_name), latent)
+
+                # flip
+                if args.flip_aug:
+                    latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype)  # copyがないとTensor変換できない
+
+                    for (image_key, reso, _), latent in zip(bucket, latents):
+                        npz_file_name = os.path.splitext(os.path.basename(image_key))[0]
+                        np.savez(os.path.join(os.path.dirname(image_key), npz_file_name + '_flip'), latent)
+
+                bucket.clear()
+
+    for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
+        print(f"bucket {i} {reso}: {count}")
+    img_ar_errors = np.array(img_ar_errors)
+    print(f"mean ar error: {np.mean(img_ar_errors)}")
+
+    # metadataを書き出して終わり
+    print(f"writing metadata: {args.out_json}")
+    with open(args.out_json, "wt", encoding='utf-8') as f:
+        json.dump(metadata, f, indent=2)
+    print("done!")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+    parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
+    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+    parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
+    parser.add_argument("--v2", action='store_true',
+                        help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+    parser.add_argument("--max_resolution", type=str, default="512,512",
+                        help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
+    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
+    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
+    parser.add_argument("--mixed_precision", type=str, default="no",
+                        choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
+    parser.add_argument("--full_path", action="store_true",
+                        help="use full path as image-key in metadata (supports multiple 
directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)") + parser.add_argument("--flip_aug", action="store_true", + help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") + + args = parser.parse_args() + main(args) diff --git a/finetune_gui copy.py b/finetune_gui copy.py new file mode 100644 index 0000000..cb9f672 --- /dev/null +++ b/finetune_gui copy.py @@ -0,0 +1,789 @@ +import gradio as gr +import json +import math +import os +import subprocess +import pathlib +import shutil +import argparse +from library.common_gui import ( + get_folder_path, + get_file_path, + get_any_file_path, + get_saveasfile_path, +) +from library.utilities import utilities_tab + +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +document_symbol = '\U0001F4C4' # 📄 + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + train_text_encoder, + create_caption, + create_buckets, + save_model_as, + caption_extension, + use_8bit_adam, + xformers, + clip_skip, +): + original_file_path = file_path + + save_as_bool = True if save_as.get('label') == 'True' else False + + if save_as_bool: + print('Save as...') + file_path = get_saveasfile_path(file_path) + else: + print('Save...') + if file_path == None or file_path == '': + file_path = get_saveasfile_path(file_path) + + # print(file_path) + + if file_path == None: + return original_file_path + + # Return the values of the variables as a dictionary + variables = { + 'pretrained_model_name_or_path': pretrained_model_name_or_path, + 'v2': v2, + 'v_parameterization': v_parameterization, + 'train_dir': train_dir, + 'image_folder': image_folder, + 'output_dir': output_dir, + 'logging_dir': logging_dir, + 'max_resolution': max_resolution, + 'min_bucket_reso': min_bucket_reso, + 'max_bucket_reso': max_bucket_reso, + 'batch_size': batch_size, + 'flip_aug': flip_aug, + 'caption_metadata_filename': caption_metadata_filename, + 'latent_metadata_filename': latent_metadata_filename, + 'full_path': full_path, + 'learning_rate': learning_rate, + 'lr_scheduler': lr_scheduler, + 'lr_warmup': lr_warmup, + 'dataset_repeats': dataset_repeats, + 'train_batch_size': train_batch_size, + 'epoch': epoch, + 'save_every_n_epochs': save_every_n_epochs, + 'mixed_precision': mixed_precision, + 'save_precision': save_precision, + 'seed': seed, + 'num_cpu_threads_per_process': num_cpu_threads_per_process, + 'train_text_encoder': train_text_encoder, + 'create_buckets': create_buckets, + 'create_caption': create_caption, + 'save_model_as': save_model_as, + 'caption_extension': caption_extension, + 'use_8bit_adam': use_8bit_adam, + 'xformers': xformers, + 'clip_skip': clip_skip, + } + + # Save the data to the selected file + with open(file_path, 'w') as file: + json.dump(variables, file) + + return file_path + + +def open_config_file( + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, 
+ latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + train_text_encoder, + create_caption, + create_buckets, + save_model_as, + caption_extension, + use_8bit_adam, + xformers, + clip_skip, +): + original_file_path = file_path + file_path = get_file_path(file_path) + + if file_path != '' and file_path != None: + print(file_path) + # load variables from JSON file + with open(file_path, 'r') as f: + my_data = json.load(f) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + # Return the values of the variables as a dictionary + return ( + file_path, + my_data.get( + 'pretrained_model_name_or_path', pretrained_model_name_or_path + ), + my_data.get('v2', v2), + my_data.get('v_parameterization', v_parameterization), + my_data.get('train_dir', train_dir), + my_data.get('image_folder', image_folder), + my_data.get('output_dir', output_dir), + my_data.get('logging_dir', logging_dir), + my_data.get('max_resolution', max_resolution), + my_data.get('min_bucket_reso', min_bucket_reso), + my_data.get('max_bucket_reso', max_bucket_reso), + my_data.get('batch_size', batch_size), + my_data.get('flip_aug', flip_aug), + my_data.get('caption_metadata_filename', caption_metadata_filename), + my_data.get('latent_metadata_filename', latent_metadata_filename), + my_data.get('full_path', full_path), + my_data.get('learning_rate', learning_rate), + my_data.get('lr_scheduler', lr_scheduler), + my_data.get('lr_warmup', lr_warmup), + my_data.get('dataset_repeats', dataset_repeats), + my_data.get('train_batch_size', train_batch_size), + my_data.get('epoch', epoch), + my_data.get('save_every_n_epochs', save_every_n_epochs), + my_data.get('mixed_precision', mixed_precision), + my_data.get('save_precision', save_precision), + my_data.get('seed', seed), + my_data.get( + 'num_cpu_threads_per_process', num_cpu_threads_per_process + ), + my_data.get('train_text_encoder', train_text_encoder), + my_data.get('create_buckets', create_buckets), + my_data.get('create_caption', create_caption), + my_data.get('save_model_as', save_model_as), + my_data.get('caption_extension', caption_extension), + my_data.get('use_8bit_adam', use_8bit_adam), + my_data.get('xformers', xformers), + my_data.get('clip_skip', clip_skip), + ) + + +def train_model( + pretrained_model_name_or_path, + v2, + v_parameterization, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + train_text_encoder, + generate_caption_database, + generate_image_buckets, + save_model_as, + caption_extension, + use_8bit_adam, + xformers, + clip_skip, +): + def save_inference_file(output_dir, v2, v_parameterization): + # Copy inference model for v2 if required + if v2 and v_parameterization: + print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml') + shutil.copy( + f'./v2_inference/v2-inference-v.yaml', + f'{output_dir}/last.yaml', + ) + elif v2: + print(f'Saving v2-inference.yaml as {output_dir}/last.yaml') + shutil.copy( + f'./v2_inference/v2-inference.yaml', 
+ f'{output_dir}/last.yaml', + ) + + # create caption json file + if generate_caption_database: + if not os.path.exists(train_dir): + os.mkdir(train_dir) + + for root, dirs, files in os.walk(image_folder): + for dir in dirs: + print(os.path.join(root, dir)) + + run_cmd = ( + f'./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py' + ) + if caption_extension == '': + run_cmd += f' --caption_extension=".txt"' + else: + run_cmd += f' --caption_extension={caption_extension}' + run_cmd += f' "{os.path.join(root, dir)}"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + if full_path: + run_cmd += f' --full_path' + + print(run_cmd) + + # Run the command + subprocess.run(run_cmd) + + # create images buckets + if generate_image_buckets: + run_cmd = ( + f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py' + ) + run_cmd += f' "crap"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + run_cmd += f' "{train_dir}/{latent_metadata_filename}"' + run_cmd += f' "{pretrained_model_name_or_path}"' + run_cmd += f' --batch_size={batch_size}' + run_cmd += f' --max_resolution={max_resolution}' + run_cmd += f' --min_bucket_reso={min_bucket_reso}' + run_cmd += f' --max_bucket_reso={max_bucket_reso}' + run_cmd += f' --mixed_precision={mixed_precision}' + if flip_aug: + run_cmd += f' --flip_aug' + if full_path: + run_cmd += f' --full_path' + + print(run_cmd) + + # Run the command + subprocess.run(run_cmd) + + image_num = 0 + for root, dirs, files in os.walk(image_folder): + for dir in dirs: + image_num += len( + [f for f in os.listdir(os.path.join(root, dir)) if f.endswith('.npz')] + ) + print(f'image_num = {image_num}') + + repeats = int(image_num) * int(dataset_repeats) + print(f'repeats = {str(repeats)}') + + # calculate max_train_steps + max_train_steps = int( + math.ceil(float(repeats) / int(train_batch_size) * int(epoch)) + ) + + # Divide by two because flip augmentation create two copied of the source images + if flip_aug: + max_train_steps = int(math.ceil(float(max_train_steps) / 2)) + + print(f'max_train_steps = {max_train_steps}') + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + print(f'lr_warmup_steps = {lr_warmup_steps}') + + run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"' + if v2: + run_cmd += ' --v2' + if v_parameterization: + run_cmd += ' --v_parameterization' + if train_text_encoder: + run_cmd += ' --train_text_encoder' + if use_8bit_adam: + run_cmd += f' --use_8bit_adam' + if xformers: + run_cmd += f' --xformers' + run_cmd += ( + f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"' + ) + run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"' + run_cmd += f' --train_data_dir="{image_folder}"' + run_cmd += f' --output_dir="{output_dir}"' + if not logging_dir == '': + run_cmd += f' --logging_dir="{logging_dir}"' + run_cmd += f' --train_batch_size={train_batch_size}' + run_cmd += f' --dataset_repeats={dataset_repeats}' + run_cmd += f' --learning_rate={learning_rate}' + run_cmd += f' --lr_scheduler={lr_scheduler}' + run_cmd += f' --lr_warmup_steps={lr_warmup_steps}' + run_cmd += f' --max_train_steps={max_train_steps}' + run_cmd += f' --mixed_precision={mixed_precision}' + run_cmd += f' --save_every_n_epochs={save_every_n_epochs}' + run_cmd += f' --seed={seed}' + run_cmd += f' --save_precision={save_precision}' + if not save_model_as == 'same as source model': + run_cmd += f' --save_model_as={save_model_as}' + if int(clip_skip) > 1: + run_cmd 
+= f' --clip_skip={str(clip_skip)}' + + print(run_cmd) + # Run the command + subprocess.run(run_cmd) + + # check if output_dir/last is a folder... therefore it is a diffuser model + last_dir = pathlib.Path(f'{output_dir}/last') + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file(output_dir, v2, v_parameterization) + + +def set_pretrained_model_name_or_path_input(value, v2, v_parameterization): + # define a list of substrings to search for + substrings_v2 = [ + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + ] + + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list + if str(value) in substrings_v2: + print('SD v2 model detected. Setting --v2 parameter') + v2 = True + v_parameterization = False + + return value, v2, v_parameterization + + # define a list of substrings to search for v-objective + substrings_v_parameterization = [ + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + ] + + # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list + if str(value) in substrings_v_parameterization: + print( + 'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization' + ) + v2 = True + v_parameterization = True + + return value, v2, v_parameterization + + # define a list of substrings to v1.x + substrings_v1_model = [ + 'CompVis/stable-diffusion-v1-4', + 'runwayml/stable-diffusion-v1-5', + ] + + if str(value) in substrings_v1_model: + v2 = False + v_parameterization = False + + return value, v2, v_parameterization + + if value == 'custom': + value = '' + v2 = False + v_parameterization = False + + return value, v2, v_parameterization + + +def remove_doublequote(file_path): + if file_path != None: + file_path = file_path.replace('"', '') + + return file_path + + +def UI(username, password): + + css = '' + + if os.path.exists('./style.css'): + with open(os.path.join('./style.css'), 'r', encoding='utf8') as file: + print('Load CSS...') + css += file.read() + '\n' + + interface = gr.Blocks(css=css) + + with interface: + with gr.Tab('Finetune'): + finetune_tab() + with gr.Tab('Utilities'): + utilities_tab(enable_dreambooth_tab=False) + + # Show the interface + if not username == '': + interface.launch(auth=(username, password)) + else: + interface.launch() + + +def finetune_tab(): + dummy_ft_true = gr.Label(value=True, visible=False) + dummy_ft_false = gr.Label(value=False, visible=False) + gr.Markdown('Train a custom model using kohya finetune python code...') + with gr.Accordion('Configuration file', open=False): + with gr.Row(): + button_open_config = gr.Button( + f'Open {folder_symbol}', elem_id='open_folder' + ) + button_save_config = gr.Button( + f'Save {save_style_symbol}', elem_id='open_folder' + ) + button_save_as_config = gr.Button( + f'Save as... {save_style_symbol}', + elem_id='open_folder', + ) + config_file_name = gr.Textbox( + label='', placeholder='type file path or use buttons...' 
+ ) + config_file_name.change( + remove_doublequote, + inputs=[config_file_name], + outputs=[config_file_name], + ) + with gr.Tab('Source model'): + # Define the input elements + with gr.Row(): + pretrained_model_name_or_path_input = gr.Textbox( + label='Pretrained model name or path', + placeholder='enter the path to custom model or name of pretrained model', + ) + pretrained_model_name_or_path_file = gr.Button( + document_symbol, elem_id='open_folder_small' + ) + pretrained_model_name_or_path_file.click( + get_any_file_path, + inputs=pretrained_model_name_or_path_input, + outputs=pretrained_model_name_or_path_input, + ) + pretrained_model_name_or_path_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + pretrained_model_name_or_path_folder.click( + get_folder_path, + inputs=pretrained_model_name_or_path_input, + outputs=pretrained_model_name_or_path_input, + ) + model_list = gr.Dropdown( + label='(Optional) Model Quick Pick', + choices=[ + 'custom', + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + 'runwayml/stable-diffusion-v1-5', + 'CompVis/stable-diffusion-v1-4', + ], + ) + save_model_as_dropdown = gr.Dropdown( + label='Save trained model as', + choices=[ + 'same as source model', + 'ckpt', + 'diffusers', + 'diffusers_safetensors', + 'safetensors', + ], + value='same as source model', + ) + + with gr.Row(): + v2_input = gr.Checkbox(label='v2', value=True) + v_parameterization_input = gr.Checkbox( + label='v_parameterization', value=False + ) + model_list.change( + set_pretrained_model_name_or_path_input, + inputs=[model_list, v2_input, v_parameterization_input], + outputs=[ + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + ], + ) + with gr.Tab('Folders'): + with gr.Row(): + train_dir_input = gr.Textbox( + label='Training config folder', + placeholder='folder where the training configuration files will be saved', + ) + train_dir_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + train_dir_folder.click(get_folder_path, outputs=train_dir_input) + + image_folder_input = gr.Textbox( + label='Training Image folder', + placeholder='folder where the training images are located', + ) + image_folder_input_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + image_folder_input_folder.click( + get_folder_path, outputs=image_folder_input + ) + with gr.Row(): + output_dir_input = gr.Textbox( + label='Output folder', + placeholder='folder where the model will be saved', + ) + output_dir_input_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + output_dir_input_folder.click( + get_folder_path, outputs=output_dir_input + ) + + logging_dir_input = gr.Textbox( + label='Logging folder', + placeholder='Optional: enable logging and output TensorBoard log to this folder', + ) + logging_dir_input_folder = gr.Button( + folder_symbol, elem_id='open_folder_small' + ) + logging_dir_input_folder.click( + get_folder_path, outputs=logging_dir_input + ) + train_dir_input.change( + remove_doublequote, + inputs=[train_dir_input], + outputs=[train_dir_input], + ) + image_folder_input.change( + remove_doublequote, + inputs=[image_folder_input], + outputs=[image_folder_input], + ) + output_dir_input.change( + remove_doublequote, + inputs=[output_dir_input], + outputs=[output_dir_input], + ) + with gr.Tab('Dataset preparation'): + with gr.Row(): + max_resolution_input = gr.Textbox( + label='Resolution 
(width,height)', value='512,512' + ) + min_bucket_reso = gr.Textbox( + label='Min bucket resolution', value='256' + ) + max_bucket_reso = gr.Textbox( + label='Max bucket resolution', value='1024' + ) + batch_size = gr.Textbox(label='Batch size', value='1') + with gr.Accordion('Advanced parameters', open=False): + with gr.Row(): + caption_metadata_filename = gr.Textbox( + label='Caption metadata filename', value='meta_cap.json' + ) + latent_metadata_filename = gr.Textbox( + label='Latent metadata filename', value='meta_lat.json' + ) + full_path = gr.Checkbox(label='Use full path', value=True) + flip_aug = gr.Checkbox(label='Flip augmentation', value=False) + with gr.Tab('Training parameters'): + with gr.Row(): + learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6) + lr_scheduler_input = gr.Dropdown( + label='LR Scheduler', + choices=[ + 'constant', + 'constant_with_warmup', + 'cosine', + 'cosine_with_restarts', + 'linear', + 'polynomial', + ], + value='constant', + ) + lr_warmup_input = gr.Textbox(label='LR warmup', value=0) + with gr.Row(): + dataset_repeats_input = gr.Textbox( + label='Dataset repeats', value=40 + ) + train_batch_size_input = gr.Slider( + minimum=1, + maximum=32, + label='Train batch size', + value=1, + step=1, + ) + epoch_input = gr.Textbox(label='Epoch', value=1) + save_every_n_epochs_input = gr.Textbox( + label='Save every N epochs', value=1 + ) + with gr.Row(): + mixed_precision_input = gr.Dropdown( + label='Mixed precision', + choices=[ + 'no', + 'fp16', + 'bf16', + ], + value='fp16', + ) + save_precision_input = gr.Dropdown( + label='Save precision', + choices=[ + 'float', + 'fp16', + 'bf16', + ], + value='fp16', + ) + num_cpu_threads_per_process_input = gr.Slider( + minimum=1, + maximum=os.cpu_count(), + step=1, + label='Number of CPU threads per process', + value=os.cpu_count(), + ) + seed_input = gr.Textbox(label='Seed', value=1234) + with gr.Row(): + caption_extention_input = gr.Textbox( + label='Caption Extension', + placeholder='(Optional) Extension for caption files. 
default: .txt', + ) + train_text_encoder_input = gr.Checkbox( + label='Train text encoder', value=True + ) + with gr.Accordion('Advanced parameters', open=False): + with gr.Row(): + use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) + xformers = gr.Checkbox(label='Use xformers', value=True) + clip_skip = gr.Slider( + label='Clip skip', value='1', minimum=1, maximum=12, step=1 + ) + with gr.Box(): + with gr.Row(): + create_caption = gr.Checkbox( + label='Generate caption metadata', value=True + ) + create_buckets = gr.Checkbox( + label='Generate image buckets metadata', value=True + ) + + button_run = gr.Button('Train model') + + settings_list = [ + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + train_dir_input, + image_folder_input, + output_dir_input, + logging_dir_input, + max_resolution_input, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate_input, + lr_scheduler_input, + lr_warmup_input, + dataset_repeats_input, + train_batch_size_input, + epoch_input, + save_every_n_epochs_input, + mixed_precision_input, + save_precision_input, + seed_input, + num_cpu_threads_per_process_input, + train_text_encoder_input, + create_caption, + create_buckets, + save_model_as_dropdown, + caption_extention_input, + use_8bit_adam, + xformers, + clip_skip, + ] + + button_run.click(train_model, inputs=settings_list) + + button_open_config.click( + open_config_file, + inputs=[config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, + ) + + button_save_config.click( + save_configuration, + inputs=[dummy_ft_false, config_file_name] + settings_list, + outputs=[config_file_name], + ) + + button_save_as_config.click( + save_configuration, + inputs=[dummy_ft_true, config_file_name] + settings_list, + outputs=[config_file_name], + ) + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + + args = parser.parse_args() + + UI(username=args.username, password=args.password) diff --git a/lora_gui.py b/lora_gui.py index 068b926..2f83d18 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -79,6 +79,7 @@ def save_configuration( gradient_accumulation_steps, mem_eff_attn, output_name, + model_list, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -120,32 +121,32 @@ def save_configuration( def open_configuration( file_path, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latent, + caption_extention, + enable_bucket, gradient_checkpointing, - 
full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, + full_fp16, + no_token_padding, + stop_text_encoder_training, + use_8bit_adam, + xformers, save_model_as_dropdown, shuffle_caption, save_state, @@ -161,27 +162,28 @@ def open_configuration( gradient_accumulation_steps, mem_eff_attn, output_name, + model_list, ): # Get list of function parameters and values parameters = list(locals().items()) - + original_file_path = file_path file_path = get_file_path(file_path) if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: - my_data_lora = json.load(f) + my_data = json.load(f) print("Loading config...") else: file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data_lora = {} - + my_data = {} + values = [file_path] for key, value in parameters: # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found if not key in ['file_path']: - values.append(my_data_lora.get(key, value)) + values.append(my_data.get(key, value)) return tuple(values) @@ -227,6 +229,7 @@ def train_model( gradient_accumulation_steps, mem_eff_attn, output_name, + model_list, ): if pretrained_model_name_or_path == '': msgbox('Source model information is missing') @@ -480,7 +483,7 @@ def lora_tab( with gr.Tab('Source model'): # Define the input elements with gr.Row(): - pretrained_model_name_or_path_input = gr.Textbox( + pretrained_model_name_or_path = gr.Textbox( label='Pretrained model name or path', placeholder='enter the path to custom model or name of pretrained model', ) @@ -489,15 +492,15 @@ def lora_tab( ) pretrained_model_name_or_path_file.click( get_any_file_path, - inputs=[pretrained_model_name_or_path_input], - outputs=pretrained_model_name_or_path_input, + inputs=[pretrained_model_name_or_path], + outputs=pretrained_model_name_or_path, ) pretrained_model_name_or_path_folder = gr.Button( folder_symbol, elem_id='open_folder_small' ) pretrained_model_name_or_path_folder.click( get_folder_path, - outputs=pretrained_model_name_or_path_input, + outputs=pretrained_model_name_or_path, ) model_list = gr.Dropdown( label='(Optional) Model Quick Pick', @@ -524,67 +527,67 @@ def lora_tab( ) with gr.Row(): - v2_input = gr.Checkbox(label='v2', value=True) - v_parameterization_input = gr.Checkbox( + v2 = gr.Checkbox(label='v2', value=True) + v_parameterization = gr.Checkbox( label='v_parameterization', value=False ) - pretrained_model_name_or_path_input.change( + pretrained_model_name_or_path.change( remove_doublequote, - inputs=[pretrained_model_name_or_path_input], - outputs=[pretrained_model_name_or_path_input], + inputs=[pretrained_model_name_or_path], + outputs=[pretrained_model_name_or_path], ) model_list.change( set_pretrained_model_name_or_path_input, - inputs=[model_list, v2_input, v_parameterization_input], + inputs=[model_list, v2, v_parameterization], outputs=[ - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, + pretrained_model_name_or_path, + v2, + v_parameterization, ], ) with gr.Tab('Folders'): with gr.Row(): - train_data_dir_input = gr.Textbox( + train_data_dir = gr.Textbox( label='Image folder', placeholder='Folder where the training folders containing the images are located', ) - train_data_dir_input_folder = gr.Button( + train_data_dir_folder = gr.Button( '📂', elem_id='open_folder_small' ) - train_data_dir_input_folder.click( - get_folder_path, 
outputs=train_data_dir_input + train_data_dir_folder.click( + get_folder_path, outputs=train_data_dir ) - reg_data_dir_input = gr.Textbox( + reg_data_dir = gr.Textbox( label='Regularisation folder', placeholder='(Optional) Folder where where the regularization folders containing the images are located', ) - reg_data_dir_input_folder = gr.Button( + reg_data_dir_folder = gr.Button( '📂', elem_id='open_folder_small' ) - reg_data_dir_input_folder.click( - get_folder_path, outputs=reg_data_dir_input + reg_data_dir_folder.click( + get_folder_path, outputs=reg_data_dir ) with gr.Row(): - output_dir_input = gr.Textbox( + output_dir = gr.Textbox( label='Output folder', placeholder='Folder to output trained model', ) - output_dir_input_folder = gr.Button( + output_dir_folder = gr.Button( '📂', elem_id='open_folder_small' ) - output_dir_input_folder.click( - get_folder_path, outputs=output_dir_input + output_dir_folder.click( + get_folder_path, outputs=output_dir ) - logging_dir_input = gr.Textbox( + logging_dir = gr.Textbox( label='Logging folder', placeholder='Optional: enable logging and output TensorBoard log to this folder', ) - logging_dir_input_folder = gr.Button( + logging_dir_folder = gr.Button( '📂', elem_id='open_folder_small' ) - logging_dir_input_folder.click( - get_folder_path, outputs=logging_dir_input + logging_dir_folder.click( + get_folder_path, outputs=logging_dir ) with gr.Row(): output_name = gr.Textbox( @@ -593,25 +596,25 @@ def lora_tab( value='last', interactive=True, ) - train_data_dir_input.change( + train_data_dir.change( remove_doublequote, - inputs=[train_data_dir_input], - outputs=[train_data_dir_input], + inputs=[train_data_dir], + outputs=[train_data_dir], ) - reg_data_dir_input.change( + reg_data_dir.change( remove_doublequote, - inputs=[reg_data_dir_input], - outputs=[reg_data_dir_input], + inputs=[reg_data_dir], + outputs=[reg_data_dir], ) - output_dir_input.change( + output_dir.change( remove_doublequote, - inputs=[output_dir_input], - outputs=[output_dir_input], + inputs=[output_dir], + outputs=[output_dir], ) - logging_dir_input.change( + logging_dir.change( remove_doublequote, - inputs=[logging_dir_input], - outputs=[logging_dir_input], + inputs=[logging_dir], + outputs=[logging_dir], ) with gr.Tab('Training parameters'): with gr.Row(): @@ -628,7 +631,7 @@ def lora_tab( outputs=lora_network_weights, ) with gr.Row(): - lr_scheduler_input = gr.Dropdown( + lr_scheduler = gr.Dropdown( label='LR Scheduler', choices=[ 'constant', @@ -640,7 +643,7 @@ def lora_tab( ], value='cosine', ) - lr_warmup_input = gr.Textbox(label='LR warmup (% of steps)', value=10) + lr_warmup = gr.Textbox(label='LR warmup (% of steps)', value=10) with gr.Row(): text_encoder_lr = gr.Textbox( label='Text Encoder learning rate', @@ -659,19 +662,19 @@ def lora_tab( interactive=True, ) with gr.Row(): - train_batch_size_input = gr.Slider( + train_batch_size = gr.Slider( minimum=1, maximum=32, label='Train batch size', value=1, step=1, ) - epoch_input = gr.Textbox(label='Epoch', value=1) - save_every_n_epochs_input = gr.Textbox( + epoch = gr.Textbox(label='Epoch', value=1) + save_every_n_epochs = gr.Textbox( label='Save every N epochs', value=1 ) with gr.Row(): - mixed_precision_input = gr.Dropdown( + mixed_precision = gr.Dropdown( label='Mixed precision', choices=[ 'no', @@ -680,7 +683,7 @@ def lora_tab( ], value='fp16', ) - save_precision_input = gr.Dropdown( + save_precision = gr.Dropdown( label='Save precision', choices=[ 'float', @@ -689,7 +692,7 @@ def lora_tab( ], value='fp16', ) - 
num_cpu_threads_per_process_input = gr.Slider( + num_cpu_threads_per_process = gr.Slider( minimum=1, maximum=os.cpu_count(), step=1, @@ -697,18 +700,18 @@ def lora_tab( value=os.cpu_count(), ) with gr.Row(): - seed_input = gr.Textbox(label='Seed', value=1234) - max_resolution_input = gr.Textbox( + seed = gr.Textbox(label='Seed', value=1234) + max_resolution = gr.Textbox( label='Max resolution', value='512,512', placeholder='512,512', ) with gr.Row(): - caption_extention_input = gr.Textbox( + caption_extention = gr.Textbox( label='Caption Extension', placeholder='(Optional) Extension for caption files. default: .caption', ) - stop_text_encoder_training_input = gr.Slider( + stop_text_encoder_training = gr.Slider( minimum=0, maximum=100, value=0, @@ -716,20 +719,20 @@ def lora_tab( label='Stop text encoder training', ) with gr.Row(): - enable_bucket_input = gr.Checkbox( + enable_bucket = gr.Checkbox( label='Enable buckets', value=True ) - cache_latent_input = gr.Checkbox(label='Cache latent', value=True) - use_8bit_adam_input = gr.Checkbox( + cache_latent = gr.Checkbox(label='Cache latent', value=True) + use_8bit_adam = gr.Checkbox( label='Use 8bit adam', value=True ) - xformers_input = gr.Checkbox(label='Use xformers', value=True) + xformers = gr.Checkbox(label='Use xformers', value=True) with gr.Accordion('Advanced Configuration', open=False): with gr.Row(): - full_fp16_input = gr.Checkbox( + full_fp16 = gr.Checkbox( label='Full fp16 training (experimental)', value=False ) - no_token_padding_input = gr.Checkbox( + no_token_padding = gr.Checkbox( label='No token padding', value=False ) @@ -754,7 +757,7 @@ def lora_tab( color_aug.change( color_aug_changed, inputs=[color_aug], - outputs=[cache_latent_input], + outputs=[cache_latent], ) clip_skip = gr.Slider( label='Clip skip', value='1', minimum=1, maximum=12, step=1 @@ -783,10 +786,10 @@ def lora_tab( 'This section provide Dreambooth tools to help setup your dataset...' 
) gradio_dreambooth_folder_creation_tab( - train_data_dir_input=train_data_dir_input, - reg_data_dir_input=reg_data_dir_input, - output_dir_input=output_dir_input, - logging_dir_input=logging_dir_input, + train_data_dir_input=train_data_dir, + reg_data_dir_input=reg_data_dir, + output_dir_input=output_dir, + logging_dir_input=logging_dir, ) gradio_dataset_balancing_tab() gradio_merge_lora_tab() @@ -794,32 +797,32 @@ def lora_tab( button_run = gr.Button('Train model') settings_list = [ - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latent, + caption_extention, + enable_bucket, gradient_checkpointing, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, + full_fp16, + no_token_padding, + stop_text_encoder_training, + use_8bit_adam, + xformers, save_model_as_dropdown, shuffle_caption, save_state, @@ -835,6 +838,7 @@ def lora_tab( gradient_accumulation_steps, mem_eff_attn, output_name, + model_list, ] button_open_config.click( @@ -861,10 +865,10 @@ def lora_tab( ) return ( - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - logging_dir_input, + train_data_dir, + reg_data_dir, + output_dir, + logging_dir, )
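
A note on the bucketing logic introduced in finetune/prepare_buckets_latents_new.py above: each image is matched to the bucket whose aspect ratio is closest, scaled so the side that will not be trimmed lands exactly on the bucket size, then center-cropped on the other axis. The standalone sketch below mirrors that flow under illustrative assumptions: the hard-coded BUCKET_RESOS list and the assign_and_crop helper are made up for the example, whereas the script derives its bucket list from model_util.make_bucket_resolutions().

import numpy as np
import cv2
from PIL import Image

# Hand-picked example buckets for a 512x512 budget; the script computes the real list.
BUCKET_RESOS = [(512, 512), (576, 448), (448, 576), (640, 384), (384, 640)]
BUCKET_ASPECT_RATIOS = np.array([w / h for w, h in BUCKET_RESOS])


def assign_and_crop(image: Image.Image):
    """Pick the bucket with the closest aspect ratio, resize, then center-crop (illustrative)."""
    aspect_ratio = image.width / image.height
    ar_errors = BUCKET_ASPECT_RATIOS - aspect_ratio
    bucket_id = int(np.abs(ar_errors).argmin())
    reso = BUCKET_RESOS[bucket_id]

    # Scale so the side that will NOT be trimmed lands exactly on the bucket size.
    if ar_errors[bucket_id] <= 0:      # image is wider than the bucket -> match height, trim width
        scale = reso[1] / image.height
    else:                              # image is taller -> match width, trim height
        scale = reso[0] / image.width
    resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))

    img = cv2.resize(np.array(image), resized_size, interpolation=cv2.INTER_AREA)
    if resized_size[0] > reso[0]:      # center-crop left/right
        trim = resized_size[0] - reso[0]
        img = img[:, trim // 2:trim // 2 + reso[0]]
    elif resized_size[1] > reso[1]:    # center-crop top/bottom
        trim = resized_size[1] - reso[1]
        img = img[trim // 2:trim // 2 + reso[1]]
    return bucket_id, reso, img        # img is reso[1] x reso[0] x 3 (H x W x C)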
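
The same script caches latents by running each bucket batch through the diffusers AutoencoderKL (vae.encode(...).latent_dist.sample()) and writing <image name>.npz next to the image, plus <image name>_flip.npz when --flip_aug is set. A minimal single-image sketch of that round trip, assuming a generic VAE checkpoint ('stabilityai/sd-vae-ft-mse' and 'example.png' are only stand-ins; the script actually loads the VAE from the model passed on the command line via model_util.load_vae):

import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL
from torchvision import transforms

# Same normalization the script applies before encoding.
IMAGE_TRANSFORMS = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse').to(device).eval()

image = Image.open('example.png').convert('RGB')   # assumed already resized/cropped to its bucket
img_tensor = IMAGE_TRANSFORMS(image).unsqueeze(0).to(device)
with torch.no_grad():
    latent = vae.encode(img_tensor).latent_dist.sample().float().cpu().numpy()[0]

np.savez('example', latent)                        # same layout the script writes next to the image
restored = np.load('example.npz')['arr_0']         # positional arrays are stored under arr_0
print(restored.shape)                              # (4, H/8, W/8) for a Stable Diffusion VAE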
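
In finetune_gui copy.py, train_model() derives the training schedule from the number of cached .npz files: repeats = images x dataset repeats, steps = ceil(repeats / batch size x epochs), halved when flip augmentation doubled the latent count, and warmup expressed as a percentage of the total. The helper below simply restates that arithmetic with a worked example; the function name and the numbers are illustrative.

import math


def estimate_schedule(image_num, dataset_repeats, train_batch_size, epoch, lr_warmup_pct, flip_aug):
    """Restates the step arithmetic of train_model(); names and numbers are illustrative."""
    repeats = image_num * dataset_repeats
    max_train_steps = int(math.ceil(repeats / train_batch_size * epoch))
    if flip_aug:
        # the .npz count already includes the *_flip latents, so the GUI halves the steps
        max_train_steps = int(math.ceil(max_train_steps / 2))
    lr_warmup_steps = round(lr_warmup_pct * max_train_steps / 100)
    return max_train_steps, lr_warmup_steps


# 100 latents x 40 repeats, batch 4, 2 epochs, 10 % warmup, no flip:
# repeats = 4000, steps = ceil(4000 / 4 * 2) = 2000, warmup = 200
print(estimate_schedule(100, 40, 4, 2, 10, False))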
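
The lora_gui.py changes drop the _input suffixes so the widget variable names match the JSON keys, which lets open_configuration() snapshot its own arguments via locals() and fill each one from the loaded file, falling back to the current widget value. A trimmed-down sketch of that pattern, assuming a three-setting signature and a plain open() in place of the GUI's get_file_path() picker:

import json


def open_configuration(file_path, learning_rate, seed, model_list):
    # Snapshot the arguments before any other locals exist -- the same trick the GUI uses.
    parameters = list(locals().items())

    try:
        with open(file_path, 'r') as f:
            my_data = json.load(f)
    except (FileNotFoundError, TypeError):
        my_data = {}           # keep current widget values if nothing could be loaded

    values = [file_path]
    for key, value in parameters:
        if key != 'file_path':
            values.append(my_data.get(key, value))   # saved value wins, else the widget's value
    return tuple(values)


# If lora.json contains only {"seed": 42}, the other settings keep their current values:
# open_configuration('lora.json', 1e-4, 1234, 'custom')  ->  ('lora.json', 0.0001, 42, 'custom')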