Fix issue with lora model loading
This commit is contained in:
parent
dc5afbb057
commit
11fbc63440
180
finetune/prepare_buckets_latents_new.py
Normal file
180
finetune/prepare_buckets_latents_new.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import library.model_util as model_util
|
||||||
|
|
||||||
|
# Use the CUDA device when torch reports one is available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Preprocessing applied to each image before it is handed to the VAE:
# conversion to a tensor followed by normalization with mean 0.5 and std 0.5.
IMAGE_TRANSFORMS = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
|
||||||
|
|
||||||
|
|
||||||
|
def get_latents(vae, images, weight_dtype):
    """Encode a batch of images into VAE latents.

    Every image is passed through ``IMAGE_TRANSFORMS``, the results are
    stacked into one batch, moved to ``DEVICE`` with dtype ``weight_dtype``
    and encoded by ``vae`` without gradient tracking. One latent is sampled
    from the encoder's latent distribution for each image.

    Returns:
        The sampled latents as a float32 NumPy array on the CPU.
    """
    batch = torch.stack([IMAGE_TRANSFORMS(img) for img in images])
    batch = batch.to(DEVICE, weight_dtype)

    with torch.no_grad():
        sampled = vae.encode(batch).latent_dist.sample()

    return sampled.float().to("cpu").numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def _save_bucket_latents(vae, bucket, weight_dtype, flip=False):
    """Encode every image of one bucket and save one ``.npz`` file per image.

    Each file is written into the directory of its source image and named
    after the image's base name without extension. When ``flip`` is true the
    horizontally mirrored images are encoded instead and ``_flip`` is
    appended to the file name.
    """
    if flip:
        # Without the copy the flipped view cannot be converted to a tensor.
        images = [img[:, ::-1].copy() for _, _, img in bucket]
        suffix = '_flip'
    else:
        images = [img for _, _, img in bucket]
        suffix = ''

    latents = get_latents(vae, images, weight_dtype)
    for (image_key, _, _), latent in zip(bucket, latents):
        stem = os.path.splitext(os.path.basename(image_key))[0]
        np.savez(os.path.join(os.path.dirname(image_key), stem + suffix), latent)


def main(args):
    """Assign each image to a resolution bucket and cache its VAE latent.

    The metadata JSON at ``args.in_json`` is loaded and its keys are used as
    image paths. Every image is matched to the bucket whose aspect ratio is
    closest, resized with ``cv2.INTER_AREA`` so that it covers the bucket
    resolution and centre-cropped to it. Buckets are encoded whenever they
    hold ``args.batch_size`` images (and once more after the last image), and
    each latent is saved as ``<image stem>.npz`` next to the image, plus a
    ``_flip`` variant when ``args.flip_aug`` is set. The chosen resolution is
    stored per image under ``train_resolution`` and the metadata is written
    to ``args.out_json``.

    ``args.train_data_dir``, ``args.full_path`` and ``args.v2`` are not read
    by this function.

    Returns early, without writing anything, when ``args.in_json`` does not
    exist.
    """
    if not os.path.exists(args.in_json):
        print(f"no metadata / メタデータファイルがありません: {args.in_json}")
        return

    print(f"loading existing metadata: {args.in_json}")
    with open(args.in_json, "rt", encoding='utf-8') as f:
        metadata = json.load(f)

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
    vae.eval()
    vae.to(DEVICE, dtype=weight_dtype)

    # Compute the bucket sizes.
    max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
    assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

    bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
        max_reso, args.min_bucket_reso, args.max_bucket_reso)

    # Assign the images one by one to the best matching bucket while computing latents.
    bucket_aspect_ratios = np.array(bucket_aspect_ratios)
    buckets_imgs = [[] for _ in range(len(bucket_resos))]
    bucket_counts = [0 for _ in range(len(bucket_resos))]
    img_ar_errors = []
    image_count = len(metadata)

    for i, image_key in enumerate(tqdm(metadata, smoothing=0.0)):
        # The metadata keys double as image paths. Pixels are read while the
        # file is still open so the handle can be released right away.
        with Image.open(image_key) as opened:
            rgb = opened if opened.mode == 'RGB' else opened.convert("RGB")
            width, height = rgb.width, rgb.height
            image = np.array(rgb)

        aspect_ratio = width / height
        ar_errors = bucket_aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        reso = bucket_resos[bucket_id]
        ar_error = ar_errors[bucket_id]
        img_ar_errors.append(abs(ar_error))

        # Pick the resize scale so that the surplus is trimmed, never padded.
        if ar_error <= 0:  # image is wider than the bucket -> match the height
            scale = reso[1] / height
        else:
            scale = reso[0] / width

        resized_size = (int(width * scale + .5), int(height * scale + .5))

        assert resized_size[0] == reso[0] or resized_size[1] == reso[
            1], f"internal error, resized size not match: {reso}, {resized_size}, {width}, {height}"
        assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
            1], f"internal error, resized size too small: {reso}, {resized_size}, {width}, {height}"

        # Resize and trim the image.
        # cv2 is used because PIL has no INTER_AREA interpolation.
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
        if resized_size[0] > reso[0]:
            trim_size = resized_size[0] - reso[0]
            image = image[:, trim_size//2:trim_size//2 + reso[0]]
        elif resized_size[1] > reso[1]:
            trim_size = resized_size[1] - reso[1]
            image = image[trim_size//2:trim_size//2 + reso[1]]
        assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"

        # Add the image to its bucket's pending batch.
        buckets_imgs[bucket_id].append((image_key, reso, image))
        bucket_counts[bucket_id] += 1
        metadata[image_key]['train_resolution'] = reso

        # Encode every bucket that is full; on the last image flush all non-empty ones.
        is_last = i == image_count - 1
        for bucket in buckets_imgs:
            if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
                _save_bucket_latents(vae, bucket, weight_dtype)
                if args.flip_aug:
                    _save_bucket_latents(vae, bucket, weight_dtype, flip=True)
                bucket.clear()

    for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
        print(f"bucket {i} {reso}: {count}")
    img_ar_errors = np.array(img_ar_errors)
    # np.mean of an empty array warns before yielding nan; produce nan directly.
    mean_ar_error = np.mean(img_ar_errors) if img_ar_errors.size else float('nan')
    print(f"mean ar error: {mean_ar_error}")

    # Write out the metadata and finish.
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
    parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
    parser.add_argument("--v2", action='store_true',
                        help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument("--max_resolution", type=str, default="512,512",
                        help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    # The Japanese half of this help text used to say "minimum"; it now matches the English half.
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
    parser.add_argument("--full_path", action="store_true",
                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
    parser.add_argument("--flip_aug", action="store_true",
                        help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")

    args = parser.parse_args()
    main(args)
|
789
finetune_gui copy.py
Normal file
789
finetune_gui copy.py
Normal file
@ -0,0 +1,789 @@
|
|||||||
|
import gradio as gr
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import pathlib
|
||||||
|
import shutil
|
||||||
|
import argparse
|
||||||
|
from library.common_gui import (
|
||||||
|
get_folder_path,
|
||||||
|
get_file_path,
|
||||||
|
get_any_file_path,
|
||||||
|
get_saveasfile_path,
|
||||||
|
)
|
||||||
|
from library.utilities import utilities_tab
|
||||||
|
|
||||||
|
# Emoji glyphs used as labels on the GUI buttons.
folder_symbol = chr(0x1F4C2)  # 📂
refresh_symbol = chr(0x1F504)  # 🔄
save_style_symbol = chr(0x1F4BE)  # 💾
document_symbol = chr(0x1F4C4)  # 📄
|
||||||
|
|
||||||
|
|
||||||
|
def save_configuration(
    save_as,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    train_dir,
    image_folder,
    output_dir,
    logging_dir,
    max_resolution,
    min_bucket_reso,
    max_bucket_reso,
    batch_size,
    flip_aug,
    caption_metadata_filename,
    latent_metadata_filename,
    full_path,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    dataset_repeats,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    train_text_encoder,
    create_caption,
    create_buckets,
    save_model_as,
    caption_extension,
    use_8bit_adam,
    xformers,
    clip_skip,
):
    """Write the current GUI settings to a JSON configuration file.

    Args:
        save_as: Mapping whose ``'label'`` entry equals the string ``'True'``
            when a "save as" dialog must always be shown.
        file_path: Path of the configuration file currently in use. When it
            is empty or ``None`` the save dialog is shown as well.
        Remaining arguments: the setting values, each stored in the JSON
            object under its own parameter name.

    Returns:
        The path that was written, or the ``file_path`` passed in when no
        usable path was obtained (nothing is written in that case).
    """
    original_file_path = file_path

    if save_as.get('label') == 'True':
        print('Save as...')
        file_path = get_saveasfile_path(file_path)
    else:
        print('Save...')
        if file_path is None or file_path == '':
            file_path = get_saveasfile_path(file_path)

    # No usable path (None or an empty string): opening it would only fail,
    # so keep the previous path and write nothing.
    if not file_path:
        return original_file_path

    # Collect the settings, keyed by parameter name.
    variables = {
        'pretrained_model_name_or_path': pretrained_model_name_or_path,
        'v2': v2,
        'v_parameterization': v_parameterization,
        'train_dir': train_dir,
        'image_folder': image_folder,
        'output_dir': output_dir,
        'logging_dir': logging_dir,
        'max_resolution': max_resolution,
        'min_bucket_reso': min_bucket_reso,
        'max_bucket_reso': max_bucket_reso,
        'batch_size': batch_size,
        'flip_aug': flip_aug,
        'caption_metadata_filename': caption_metadata_filename,
        'latent_metadata_filename': latent_metadata_filename,
        'full_path': full_path,
        'learning_rate': learning_rate,
        'lr_scheduler': lr_scheduler,
        'lr_warmup': lr_warmup,
        'dataset_repeats': dataset_repeats,
        'train_batch_size': train_batch_size,
        'epoch': epoch,
        'save_every_n_epochs': save_every_n_epochs,
        'mixed_precision': mixed_precision,
        'save_precision': save_precision,
        'seed': seed,
        'num_cpu_threads_per_process': num_cpu_threads_per_process,
        'train_text_encoder': train_text_encoder,
        'create_buckets': create_buckets,
        'create_caption': create_caption,
        'save_model_as': save_model_as,
        'caption_extension': caption_extension,
        'use_8bit_adam': use_8bit_adam,
        'xformers': xformers,
        'clip_skip': clip_skip,
    }

    # Save the data to the selected file
    with open(file_path, 'w') as file:
        json.dump(variables, file)

    return file_path
|
||||||
|
|
||||||
|
|
||||||
|
def open_config_file(
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    train_dir,
    image_folder,
    output_dir,
    logging_dir,
    max_resolution,
    min_bucket_reso,
    max_bucket_reso,
    batch_size,
    flip_aug,
    caption_metadata_filename,
    latent_metadata_filename,
    full_path,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    dataset_repeats,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    train_text_encoder,
    create_caption,
    create_buckets,
    save_model_as,
    caption_extension,
    use_8bit_adam,
    xformers,
    clip_skip,
):
    """Load settings from a JSON configuration file chosen by the user.

    ``get_file_path`` is asked for a file. When it yields a non-empty path
    the JSON object in that file is read; otherwise the incoming
    ``file_path`` is kept and nothing is loaded.

    Returns:
        A tuple starting with the configuration file path, followed by one
        value per setting: the value stored in the file under that setting's
        name, or the corresponding argument when the file has no such entry.
    """
    original_file_path = file_path
    file_path = get_file_path(file_path)

    if file_path == '' or file_path is None:
        # No file selected: keep whatever path was already in use.
        file_path = original_file_path
        my_data = {}
    else:
        print(file_path)
        # Load the stored settings from the JSON file.
        with open(file_path, 'r') as f:
            my_data = json.load(f)

    # (key, fallback) pairs in the order the values are returned.
    # NOTE(review): 'create_buckets' precedes 'create_caption' here although
    # the parameters are declared the other way round - confirm this matches
    # the order of the outputs that receive the tuple.
    fallbacks = (
        ('pretrained_model_name_or_path', pretrained_model_name_or_path),
        ('v2', v2),
        ('v_parameterization', v_parameterization),
        ('train_dir', train_dir),
        ('image_folder', image_folder),
        ('output_dir', output_dir),
        ('logging_dir', logging_dir),
        ('max_resolution', max_resolution),
        ('min_bucket_reso', min_bucket_reso),
        ('max_bucket_reso', max_bucket_reso),
        ('batch_size', batch_size),
        ('flip_aug', flip_aug),
        ('caption_metadata_filename', caption_metadata_filename),
        ('latent_metadata_filename', latent_metadata_filename),
        ('full_path', full_path),
        ('learning_rate', learning_rate),
        ('lr_scheduler', lr_scheduler),
        ('lr_warmup', lr_warmup),
        ('dataset_repeats', dataset_repeats),
        ('train_batch_size', train_batch_size),
        ('epoch', epoch),
        ('save_every_n_epochs', save_every_n_epochs),
        ('mixed_precision', mixed_precision),
        ('save_precision', save_precision),
        ('seed', seed),
        ('num_cpu_threads_per_process', num_cpu_threads_per_process),
        ('train_text_encoder', train_text_encoder),
        ('create_buckets', create_buckets),
        ('create_caption', create_caption),
        ('save_model_as', save_model_as),
        ('caption_extension', caption_extension),
        ('use_8bit_adam', use_8bit_adam),
        ('xformers', xformers),
        ('clip_skip', clip_skip),
    )

    return (file_path, *(my_data.get(key, fallback) for key, fallback in fallbacks))
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    train_dir,
    image_folder,
    output_dir,
    logging_dir,
    max_resolution,
    min_bucket_reso,
    max_bucket_reso,
    batch_size,
    flip_aug,
    caption_metadata_filename,
    latent_metadata_filename,
    full_path,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    dataset_repeats,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    train_text_encoder,
    generate_caption_database,
    generate_image_buckets,
    save_model_as,
    caption_extension,
    use_8bit_adam,
    xformers,
    clip_skip,
):
    """Run the fine-tuning pipeline by launching the command-line scripts.

    Steps, in order:

    1. When ``generate_caption_database`` is set, run
       ``finetune/merge_captions_to_metadata.py`` once for every
       sub-directory found while walking ``image_folder``.
    2. When ``generate_image_buckets`` is set, run
       ``finetune/prepare_buckets_latents.py``.
    3. Count the ``.npz`` files in the sub-directories of ``image_folder``
       and derive ``max_train_steps`` and ``lr_warmup_steps`` from the count.
    4. Launch ``fine_tune.py`` through ``accelerate launch``.
    5. When ``{output_dir}/last`` is not a directory and ``v2`` is set, copy
       the matching v2 inference YAML to ``{output_dir}/last.yaml``.

    Every command line is printed before it is executed.
    """

    def save_inference_file(output_dir, v2, v_parameterization):
        # Copy inference model for v2 if required
        if v2 and v_parameterization:
            print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml')
            shutil.copy(
                f'./v2_inference/v2-inference-v.yaml',
                f'{output_dir}/last.yaml',
            )
        elif v2:
            print(f'Saving v2-inference.yaml as {output_dir}/last.yaml')
            shutil.copy(
                f'./v2_inference/v2-inference.yaml',
                f'{output_dir}/last.yaml',
            )

    # create caption json file
    if generate_caption_database:
        if not os.path.exists(train_dir):
            # makedirs also creates missing parent directories, which
            # os.mkdir would refuse to do.
            os.makedirs(train_dir)

        for root, dirs, _ in os.walk(image_folder):
            for subdir in dirs:
                subdir_path = os.path.join(root, subdir)
                print(subdir_path)

                run_cmd = (
                    f'./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py'
                )
                if caption_extension == '':
                    run_cmd += f' --caption_extension=".txt"'
                else:
                    run_cmd += f' --caption_extension={caption_extension}'
                run_cmd += f' "{subdir_path}"'
                run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
                if full_path:
                    run_cmd += f' --full_path'

                print(run_cmd)

                # Run the command
                subprocess.run(run_cmd)

    # create images buckets
    if generate_image_buckets:
        run_cmd = (
            f'./venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
        )
        # NOTE(review): the first positional argument is a literal
        # placeholder - confirm whether image_folder was intended here.
        run_cmd += f' "crap"'
        run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
        run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
        run_cmd += f' "{pretrained_model_name_or_path}"'
        run_cmd += f' --batch_size={batch_size}'
        run_cmd += f' --max_resolution={max_resolution}'
        run_cmd += f' --min_bucket_reso={min_bucket_reso}'
        run_cmd += f' --max_bucket_reso={max_bucket_reso}'
        run_cmd += f' --mixed_precision={mixed_precision}'
        if flip_aug:
            run_cmd += f' --flip_aug'
        if full_path:
            run_cmd += f' --full_path'

        print(run_cmd)

        # Run the command
        subprocess.run(run_cmd)

    # Count the .npz files found in every sub-directory of image_folder.
    image_num = 0
    for root, dirs, _ in os.walk(image_folder):
        for subdir in dirs:
            image_num += len(
                [f for f in os.listdir(os.path.join(root, subdir)) if f.endswith('.npz')]
            )
    print(f'image_num = {image_num}')

    repeats = int(image_num) * int(dataset_repeats)
    print(f'repeats = {str(repeats)}')

    # calculate max_train_steps
    max_train_steps = int(
        math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
    )

    # Divide by two because flip augmentation creates two copies of the source images
    if flip_aug:
        max_train_steps = int(math.ceil(float(max_train_steps) / 2))

    print(f'max_train_steps = {max_train_steps}')

    # lr_warmup is a percentage of max_train_steps.
    lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
    print(f'lr_warmup_steps = {lr_warmup_steps}')

    run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
    if v2:
        run_cmd += ' --v2'
    if v_parameterization:
        run_cmd += ' --v_parameterization'
    if train_text_encoder:
        run_cmd += ' --train_text_encoder'
    if use_8bit_adam:
        run_cmd += f' --use_8bit_adam'
    if xformers:
        run_cmd += f' --xformers'
    run_cmd += (
        f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
    )
    run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
    run_cmd += f' --train_data_dir="{image_folder}"'
    run_cmd += f' --output_dir="{output_dir}"'
    if logging_dir != '':
        run_cmd += f' --logging_dir="{logging_dir}"'
    run_cmd += f' --train_batch_size={train_batch_size}'
    run_cmd += f' --dataset_repeats={dataset_repeats}'
    run_cmd += f' --learning_rate={learning_rate}'
    run_cmd += f' --lr_scheduler={lr_scheduler}'
    run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
    run_cmd += f' --max_train_steps={max_train_steps}'
    run_cmd += f' --mixed_precision={mixed_precision}'
    run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
    run_cmd += f' --seed={seed}'
    run_cmd += f' --save_precision={save_precision}'
    if save_model_as != 'same as source model':
        run_cmd += f' --save_model_as={save_model_as}'
    if int(clip_skip) > 1:
        run_cmd += f' --clip_skip={str(clip_skip)}'

    print(run_cmd)
    # Run the command
    subprocess.run(run_cmd)

    # check if output_dir/last is a folder... therefore it is a diffuser model
    last_dir = pathlib.Path(f'{output_dir}/last')

    if not last_dir.is_dir():
        # Copy inference model for v2 if required
        save_inference_file(output_dir, v2, v_parameterization)
|
||||||
|
|
||||||
|
|
||||||
|
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
    """Derive the v2 / v_parameterization flags from a selected model name.

    ``str(value)`` is compared for exact equality against the known model
    names (it is not a substring search):

    * SD 2.x "base" models      -> ``(value, True, False)``
    * SD 2.x v-objective models -> ``(value, True, True)``
    * SD 1.x models             -> ``(value, False, False)``
    * ``'custom'``              -> ``('', False, False)``

    Any other value is returned together with the flags that were passed in,
    so the function always returns a 3-tuple.
    """
    name = str(value)

    # SD v2 models that do not use v-parameterization.
    names_v2 = (
        'stabilityai/stable-diffusion-2-1-base',
        'stabilityai/stable-diffusion-2-base',
    )
    if name in names_v2:
        print('SD v2 model detected. Setting --v2 parameter')
        return value, True, False

    # SD v2 models trained with the v-objective.
    names_v_parameterization = (
        'stabilityai/stable-diffusion-2-1',
        'stabilityai/stable-diffusion-2',
    )
    if name in names_v_parameterization:
        print(
            'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
        )
        return value, True, True

    # SD v1.x models.
    names_v1 = (
        'CompVis/stable-diffusion-v1-4',
        'runwayml/stable-diffusion-v1-5',
    )
    if name in names_v1:
        return value, False, False

    if value == 'custom':
        return '', False, False

    # Unrecognized model: leave the value and both flags untouched.
    return value, v2, v_parameterization
|
||||||
|
|
||||||
|
|
||||||
|
def remove_doublequote(file_path):
    """Return ``file_path`` with every double-quote character removed.

    ``None`` is passed through unchanged.
    """
    if file_path is None:
        return file_path
    return file_path.replace('"', '')
|
||||||
|
|
||||||
|
|
||||||
|
def UI(username, password):
    """Build the Gradio interface and launch it.

    The contents of ``./style.css`` are used as the interface CSS when that
    file exists. The interface holds a 'Finetune' tab and a 'Utilities' tab.
    It is launched with ``(username, password)`` authentication unless
    ``username`` is an empty string.
    """
    css = ''
    css_path = './style.css'

    if os.path.exists(css_path):
        with open(css_path, 'r', encoding='utf8') as file:
            print('Load CSS...')
            css += file.read() + '\n'

    interface = gr.Blocks(css=css)

    with interface:
        with gr.Tab('Finetune'):
            finetune_tab()
        with gr.Tab('Utilities'):
            utilities_tab(enable_dreambooth_tab=False)

    # Show the interface, protected by credentials when a username is given.
    launch_kwargs = {} if username == '' else {'auth': (username, password)}
    interface.launch(**launch_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def finetune_tab():
    """Create the widgets and event wiring of the 'Finetune' tab.

    Must be called inside an active ``gr.Blocks`` context. The function
    builds the configuration-file controls, the source-model, folders,
    dataset-preparation and training-parameter tabs, and finally binds the
    train / open / save / save-as buttons to their callbacks.

    The order of ``settings_list`` is significant: its components are
    passed positionally to ``train_model``, ``open_config_file`` and
    ``save_configuration``.
    """
    # NOTE(review): container nesting below was reconstructed from a
    # flattened listing - confirm the layout against the rendered UI.

    # Hidden constants used to tell save_configuration whether the click
    # came from "Save as..." (True) or "Save" (False).
    dummy_ft_true = gr.Label(value=True, visible=False)
    dummy_ft_false = gr.Label(value=False, visible=False)
    gr.Markdown('Train a custom model using kohya finetune python code...')
    with gr.Accordion('Configuration file', open=False):
        with gr.Row():
            button_open_config = gr.Button(
                f'Open {folder_symbol}', elem_id='open_folder'
            )
            button_save_config = gr.Button(
                f'Save {save_style_symbol}', elem_id='open_folder'
            )
            button_save_as_config = gr.Button(
                f'Save as... {save_style_symbol}',
                elem_id='open_folder',
            )
        config_file_name = gr.Textbox(
            label='', placeholder='type file path or use buttons...'
        )
        config_file_name.change(
            remove_doublequote,
            inputs=[config_file_name],
            outputs=[config_file_name],
        )
    with gr.Tab('Source model'):
        # Define the input elements
        with gr.Row():
            pretrained_model_name_or_path_input = gr.Textbox(
                label='Pretrained model name or path',
                placeholder='enter the path to custom model or name of pretrained model',
            )
            pretrained_model_name_or_path_file = gr.Button(
                document_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_file.click(
                get_any_file_path,
                inputs=pretrained_model_name_or_path_input,
                outputs=pretrained_model_name_or_path_input,
            )
            pretrained_model_name_or_path_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_folder.click(
                get_folder_path,
                inputs=pretrained_model_name_or_path_input,
                outputs=pretrained_model_name_or_path_input,
            )
            model_list = gr.Dropdown(
                label='(Optional) Model Quick Pick',
                choices=[
                    'custom',
                    'stabilityai/stable-diffusion-2-1-base',
                    'stabilityai/stable-diffusion-2-base',
                    'stabilityai/stable-diffusion-2-1',
                    'stabilityai/stable-diffusion-2',
                    'runwayml/stable-diffusion-v1-5',
                    'CompVis/stable-diffusion-v1-4',
                ],
            )
            save_model_as_dropdown = gr.Dropdown(
                label='Save trained model as',
                choices=[
                    'same as source model',
                    'ckpt',
                    'diffusers',
                    'diffusers_safetensors',
                    'safetensors',
                ],
                value='same as source model',
            )

        with gr.Row():
            v2_input = gr.Checkbox(label='v2', value=True)
            v_parameterization_input = gr.Checkbox(
                label='v_parameterization', value=False
            )
        model_list.change(
            set_pretrained_model_name_or_path_input,
            inputs=[model_list, v2_input, v_parameterization_input],
            outputs=[
                pretrained_model_name_or_path_input,
                v2_input,
                v_parameterization_input,
            ],
        )
        # Strip pasted quotes from the model path, as done for the other
        # path textboxes in this tab.
        pretrained_model_name_or_path_input.change(
            remove_doublequote,
            inputs=[pretrained_model_name_or_path_input],
            outputs=[pretrained_model_name_or_path_input],
        )
    with gr.Tab('Folders'):
        with gr.Row():
            train_dir_input = gr.Textbox(
                label='Training config folder',
                placeholder='folder where the training configuration files will be saved',
            )
            train_dir_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            train_dir_folder.click(get_folder_path, outputs=train_dir_input)

            image_folder_input = gr.Textbox(
                label='Training Image folder',
                placeholder='folder where the training images are located',
            )
            image_folder_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            image_folder_input_folder.click(
                get_folder_path, outputs=image_folder_input
            )
        with gr.Row():
            output_dir_input = gr.Textbox(
                label='Output folder',
                placeholder='folder where the model will be saved',
            )
            output_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            output_dir_input_folder.click(
                get_folder_path, outputs=output_dir_input
            )

            logging_dir_input = gr.Textbox(
                label='Logging folder',
                placeholder='Optional: enable logging and output TensorBoard log to this folder',
            )
            logging_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            logging_dir_input_folder.click(
                get_folder_path, outputs=logging_dir_input
            )
        train_dir_input.change(
            remove_doublequote,
            inputs=[train_dir_input],
            outputs=[train_dir_input],
        )
        image_folder_input.change(
            remove_doublequote,
            inputs=[image_folder_input],
            outputs=[image_folder_input],
        )
        output_dir_input.change(
            remove_doublequote,
            inputs=[output_dir_input],
            outputs=[output_dir_input],
        )
        # The logging folder gets the same quote clean-up as its siblings.
        logging_dir_input.change(
            remove_doublequote,
            inputs=[logging_dir_input],
            outputs=[logging_dir_input],
        )
    with gr.Tab('Dataset preparation'):
        with gr.Row():
            max_resolution_input = gr.Textbox(
                label='Resolution (width,height)', value='512,512'
            )
            min_bucket_reso = gr.Textbox(
                label='Min bucket resolution', value='256'
            )
            max_bucket_reso = gr.Textbox(
                label='Max bucket resolution', value='1024'
            )
            batch_size = gr.Textbox(label='Batch size', value='1')
        with gr.Accordion('Advanced parameters', open=False):
            with gr.Row():
                caption_metadata_filename = gr.Textbox(
                    label='Caption metadata filename', value='meta_cap.json'
                )
                latent_metadata_filename = gr.Textbox(
                    label='Latent metadata filename', value='meta_lat.json'
                )
                full_path = gr.Checkbox(label='Use full path', value=True)
                flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
    with gr.Tab('Training parameters'):
        with gr.Row():
            learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
            lr_scheduler_input = gr.Dropdown(
                label='LR Scheduler',
                choices=[
                    'constant',
                    'constant_with_warmup',
                    'cosine',
                    'cosine_with_restarts',
                    'linear',
                    'polynomial',
                ],
                value='constant',
            )
            lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
        with gr.Row():
            dataset_repeats_input = gr.Textbox(
                label='Dataset repeats', value=40
            )
            train_batch_size_input = gr.Slider(
                minimum=1,
                maximum=32,
                label='Train batch size',
                value=1,
                step=1,
            )
            epoch_input = gr.Textbox(label='Epoch', value=1)
            save_every_n_epochs_input = gr.Textbox(
                label='Save every N epochs', value=1
            )
        with gr.Row():
            mixed_precision_input = gr.Dropdown(
                label='Mixed precision',
                choices=[
                    'no',
                    'fp16',
                    'bf16',
                ],
                value='fp16',
            )
            save_precision_input = gr.Dropdown(
                label='Save precision',
                choices=[
                    'float',
                    'fp16',
                    'bf16',
                ],
                value='fp16',
            )
            num_cpu_threads_per_process_input = gr.Slider(
                minimum=1,
                maximum=os.cpu_count(),
                step=1,
                label='Number of CPU threads per process',
                value=os.cpu_count(),
            )
            seed_input = gr.Textbox(label='Seed', value=1234)
        with gr.Row():
            caption_extention_input = gr.Textbox(
                label='Caption Extension',
                placeholder='(Optional) Extension for caption files. default: .txt',
            )
            train_text_encoder_input = gr.Checkbox(
                label='Train text encoder', value=True
            )
        with gr.Accordion('Advanced parameters', open=False):
            with gr.Row():
                use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
                xformers = gr.Checkbox(label='Use xformers', value=True)
                clip_skip = gr.Slider(
                    label='Clip skip', value='1', minimum=1, maximum=12, step=1
                )
    with gr.Box():
        with gr.Row():
            create_caption = gr.Checkbox(
                label='Generate caption metadata', value=True
            )
            create_buckets = gr.Checkbox(
                label='Generate image buckets metadata', value=True
            )

    button_run = gr.Button('Train model')

    # Order matters: these components are handed positionally to the
    # train / open / save callbacks.
    settings_list = [
        pretrained_model_name_or_path_input,
        v2_input,
        v_parameterization_input,
        train_dir_input,
        image_folder_input,
        output_dir_input,
        logging_dir_input,
        max_resolution_input,
        min_bucket_reso,
        max_bucket_reso,
        batch_size,
        flip_aug,
        caption_metadata_filename,
        latent_metadata_filename,
        full_path,
        learning_rate_input,
        lr_scheduler_input,
        lr_warmup_input,
        dataset_repeats_input,
        train_batch_size_input,
        epoch_input,
        save_every_n_epochs_input,
        mixed_precision_input,
        save_precision_input,
        seed_input,
        num_cpu_threads_per_process_input,
        train_text_encoder_input,
        create_caption,
        create_buckets,
        save_model_as_dropdown,
        caption_extention_input,
        use_8bit_adam,
        xformers,
        clip_skip,
    ]

    button_run.click(train_model, inputs=settings_list)

    # Opening a config both reads and rewrites every setting, hence the
    # same component list on the input and output side.
    button_open_config.click(
        open_config_file,
        inputs=[config_file_name] + settings_list,
        outputs=[config_file_name] + settings_list,
    )

    button_save_config.click(
        save_configuration,
        inputs=[dummy_ft_false, config_file_name] + settings_list,
        outputs=[config_file_name],
    )

    button_save_as_config.click(
        save_configuration,
        inputs=[dummy_ft_true, config_file_name] + settings_list,
        outputs=[config_file_name],
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # torch.cuda.set_per_process_memory_fraction(0.48)

    # Command-line entry point: both credentials default to the empty
    # string and are forwarded unchanged to UI().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--username',
        type=str,
        default='',
        help='Username for authentication',
    )
    parser.add_argument(
        '--password',
        type=str,
        default='',
        help='Password for authentication',
    )

    args = parser.parse_args()
    cli_username, cli_password = args.username, args.password

    UI(username=cli_username, password=cli_password)
|
246
lora_gui.py
246
lora_gui.py
@ -79,6 +79,7 @@ def save_configuration(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
|
model_list,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -120,32 +121,32 @@ def save_configuration(
|
|||||||
|
|
||||||
def open_configuration(
|
def open_configuration(
|
||||||
file_path,
|
file_path,
|
||||||
pretrained_model_name_or_path_input,
|
pretrained_model_name_or_path,
|
||||||
v2_input,
|
v2,
|
||||||
v_parameterization_input,
|
v_parameterization,
|
||||||
logging_dir_input,
|
logging_dir,
|
||||||
train_data_dir_input,
|
train_data_dir,
|
||||||
reg_data_dir_input,
|
reg_data_dir,
|
||||||
output_dir_input,
|
output_dir,
|
||||||
max_resolution_input,
|
max_resolution,
|
||||||
lr_scheduler_input,
|
lr_scheduler,
|
||||||
lr_warmup_input,
|
lr_warmup,
|
||||||
train_batch_size_input,
|
train_batch_size,
|
||||||
epoch_input,
|
epoch,
|
||||||
save_every_n_epochs_input,
|
save_every_n_epochs,
|
||||||
mixed_precision_input,
|
mixed_precision,
|
||||||
save_precision_input,
|
save_precision,
|
||||||
seed_input,
|
seed,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process,
|
||||||
cache_latent_input,
|
cache_latent,
|
||||||
caption_extention_input,
|
caption_extention,
|
||||||
enable_bucket_input,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16_input,
|
full_fp16,
|
||||||
no_token_padding_input,
|
no_token_padding,
|
||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam,
|
||||||
xformers_input,
|
xformers,
|
||||||
save_model_as_dropdown,
|
save_model_as_dropdown,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
@ -161,6 +162,7 @@ def open_configuration(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
|
model_list,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -171,17 +173,17 @@ def open_configuration(
|
|||||||
if not file_path == '' and not file_path == None:
|
if not file_path == '' and not file_path == None:
|
||||||
# load variables from JSON file
|
# load variables from JSON file
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
my_data_lora = json.load(f)
|
my_data = json.load(f)
|
||||||
print("Loading config...")
|
print("Loading config...")
|
||||||
else:
|
else:
|
||||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||||
my_data_lora = {}
|
my_data = {}
|
||||||
|
|
||||||
values = [file_path]
|
values = [file_path]
|
||||||
for key, value in parameters:
|
for key, value in parameters:
|
||||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||||
if not key in ['file_path']:
|
if not key in ['file_path']:
|
||||||
values.append(my_data_lora.get(key, value))
|
values.append(my_data.get(key, value))
|
||||||
return tuple(values)
|
return tuple(values)
|
||||||
|
|
||||||
|
|
||||||
@ -227,6 +229,7 @@ def train_model(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
|
model_list,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -480,7 +483,7 @@ def lora_tab(
|
|||||||
with gr.Tab('Source model'):
|
with gr.Tab('Source model'):
|
||||||
# Define the input elements
|
# Define the input elements
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
pretrained_model_name_or_path_input = gr.Textbox(
|
pretrained_model_name_or_path = gr.Textbox(
|
||||||
label='Pretrained model name or path',
|
label='Pretrained model name or path',
|
||||||
placeholder='enter the path to custom model or name of pretrained model',
|
placeholder='enter the path to custom model or name of pretrained model',
|
||||||
)
|
)
|
||||||
@ -489,15 +492,15 @@ def lora_tab(
|
|||||||
)
|
)
|
||||||
pretrained_model_name_or_path_file.click(
|
pretrained_model_name_or_path_file.click(
|
||||||
get_any_file_path,
|
get_any_file_path,
|
||||||
inputs=[pretrained_model_name_or_path_input],
|
inputs=[pretrained_model_name_or_path],
|
||||||
outputs=pretrained_model_name_or_path_input,
|
outputs=pretrained_model_name_or_path,
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_folder = gr.Button(
|
pretrained_model_name_or_path_folder = gr.Button(
|
||||||
folder_symbol, elem_id='open_folder_small'
|
folder_symbol, elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_folder.click(
|
pretrained_model_name_or_path_folder.click(
|
||||||
get_folder_path,
|
get_folder_path,
|
||||||
outputs=pretrained_model_name_or_path_input,
|
outputs=pretrained_model_name_or_path,
|
||||||
)
|
)
|
||||||
model_list = gr.Dropdown(
|
model_list = gr.Dropdown(
|
||||||
label='(Optional) Model Quick Pick',
|
label='(Optional) Model Quick Pick',
|
||||||
@ -524,67 +527,67 @@ def lora_tab(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v2_input = gr.Checkbox(label='v2', value=True)
|
v2 = gr.Checkbox(label='v2', value=True)
|
||||||
v_parameterization_input = gr.Checkbox(
|
v_parameterization = gr.Checkbox(
|
||||||
label='v_parameterization', value=False
|
label='v_parameterization', value=False
|
||||||
)
|
)
|
||||||
pretrained_model_name_or_path_input.change(
|
pretrained_model_name_or_path.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[pretrained_model_name_or_path_input],
|
inputs=[pretrained_model_name_or_path],
|
||||||
outputs=[pretrained_model_name_or_path_input],
|
outputs=[pretrained_model_name_or_path],
|
||||||
)
|
)
|
||||||
model_list.change(
|
model_list.change(
|
||||||
set_pretrained_model_name_or_path_input,
|
set_pretrained_model_name_or_path_input,
|
||||||
inputs=[model_list, v2_input, v_parameterization_input],
|
inputs=[model_list, v2, v_parameterization],
|
||||||
outputs=[
|
outputs=[
|
||||||
pretrained_model_name_or_path_input,
|
pretrained_model_name_or_path,
|
||||||
v2_input,
|
v2,
|
||||||
v_parameterization_input,
|
v_parameterization,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Tab('Folders'):
|
with gr.Tab('Folders'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_data_dir_input = gr.Textbox(
|
train_data_dir = gr.Textbox(
|
||||||
label='Image folder',
|
label='Image folder',
|
||||||
placeholder='Folder where the training folders containing the images are located',
|
placeholder='Folder where the training folders containing the images are located',
|
||||||
)
|
)
|
||||||
train_data_dir_input_folder = gr.Button(
|
train_data_dir_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
train_data_dir_input_folder.click(
|
train_data_dir_folder.click(
|
||||||
get_folder_path, outputs=train_data_dir_input
|
get_folder_path, outputs=train_data_dir
|
||||||
)
|
)
|
||||||
reg_data_dir_input = gr.Textbox(
|
reg_data_dir = gr.Textbox(
|
||||||
label='Regularisation folder',
|
label='Regularisation folder',
|
||||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||||
)
|
)
|
||||||
reg_data_dir_input_folder = gr.Button(
|
reg_data_dir_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
reg_data_dir_input_folder.click(
|
reg_data_dir_folder.click(
|
||||||
get_folder_path, outputs=reg_data_dir_input
|
get_folder_path, outputs=reg_data_dir
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_dir_input = gr.Textbox(
|
output_dir = gr.Textbox(
|
||||||
label='Output folder',
|
label='Output folder',
|
||||||
placeholder='Folder to output trained model',
|
placeholder='Folder to output trained model',
|
||||||
)
|
)
|
||||||
output_dir_input_folder = gr.Button(
|
output_dir_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
output_dir_input_folder.click(
|
output_dir_folder.click(
|
||||||
get_folder_path, outputs=output_dir_input
|
get_folder_path, outputs=output_dir
|
||||||
)
|
)
|
||||||
logging_dir_input = gr.Textbox(
|
logging_dir = gr.Textbox(
|
||||||
label='Logging folder',
|
label='Logging folder',
|
||||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||||
)
|
)
|
||||||
logging_dir_input_folder = gr.Button(
|
logging_dir_folder = gr.Button(
|
||||||
'📂', elem_id='open_folder_small'
|
'📂', elem_id='open_folder_small'
|
||||||
)
|
)
|
||||||
logging_dir_input_folder.click(
|
logging_dir_folder.click(
|
||||||
get_folder_path, outputs=logging_dir_input
|
get_folder_path, outputs=logging_dir
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
output_name = gr.Textbox(
|
output_name = gr.Textbox(
|
||||||
@ -593,25 +596,25 @@ def lora_tab(
|
|||||||
value='last',
|
value='last',
|
||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
train_data_dir_input.change(
|
train_data_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[train_data_dir_input],
|
inputs=[train_data_dir],
|
||||||
outputs=[train_data_dir_input],
|
outputs=[train_data_dir],
|
||||||
)
|
)
|
||||||
reg_data_dir_input.change(
|
reg_data_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[reg_data_dir_input],
|
inputs=[reg_data_dir],
|
||||||
outputs=[reg_data_dir_input],
|
outputs=[reg_data_dir],
|
||||||
)
|
)
|
||||||
output_dir_input.change(
|
output_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[output_dir_input],
|
inputs=[output_dir],
|
||||||
outputs=[output_dir_input],
|
outputs=[output_dir],
|
||||||
)
|
)
|
||||||
logging_dir_input.change(
|
logging_dir.change(
|
||||||
remove_doublequote,
|
remove_doublequote,
|
||||||
inputs=[logging_dir_input],
|
inputs=[logging_dir],
|
||||||
outputs=[logging_dir_input],
|
outputs=[logging_dir],
|
||||||
)
|
)
|
||||||
with gr.Tab('Training parameters'):
|
with gr.Tab('Training parameters'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -628,7 +631,7 @@ def lora_tab(
|
|||||||
outputs=lora_network_weights,
|
outputs=lora_network_weights,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lr_scheduler_input = gr.Dropdown(
|
lr_scheduler = gr.Dropdown(
|
||||||
label='LR Scheduler',
|
label='LR Scheduler',
|
||||||
choices=[
|
choices=[
|
||||||
'constant',
|
'constant',
|
||||||
@ -640,7 +643,7 @@ def lora_tab(
|
|||||||
],
|
],
|
||||||
value='cosine',
|
value='cosine',
|
||||||
)
|
)
|
||||||
lr_warmup_input = gr.Textbox(label='LR warmup (% of steps)', value=10)
|
lr_warmup = gr.Textbox(label='LR warmup (% of steps)', value=10)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_encoder_lr = gr.Textbox(
|
text_encoder_lr = gr.Textbox(
|
||||||
label='Text Encoder learning rate',
|
label='Text Encoder learning rate',
|
||||||
@ -659,19 +662,19 @@ def lora_tab(
|
|||||||
interactive=True,
|
interactive=True,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_batch_size_input = gr.Slider(
|
train_batch_size = gr.Slider(
|
||||||
minimum=1,
|
minimum=1,
|
||||||
maximum=32,
|
maximum=32,
|
||||||
label='Train batch size',
|
label='Train batch size',
|
||||||
value=1,
|
value=1,
|
||||||
step=1,
|
step=1,
|
||||||
)
|
)
|
||||||
epoch_input = gr.Textbox(label='Epoch', value=1)
|
epoch = gr.Textbox(label='Epoch', value=1)
|
||||||
save_every_n_epochs_input = gr.Textbox(
|
save_every_n_epochs = gr.Textbox(
|
||||||
label='Save every N epochs', value=1
|
label='Save every N epochs', value=1
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
mixed_precision_input = gr.Dropdown(
|
mixed_precision = gr.Dropdown(
|
||||||
label='Mixed precision',
|
label='Mixed precision',
|
||||||
choices=[
|
choices=[
|
||||||
'no',
|
'no',
|
||||||
@ -680,7 +683,7 @@ def lora_tab(
|
|||||||
],
|
],
|
||||||
value='fp16',
|
value='fp16',
|
||||||
)
|
)
|
||||||
save_precision_input = gr.Dropdown(
|
save_precision = gr.Dropdown(
|
||||||
label='Save precision',
|
label='Save precision',
|
||||||
choices=[
|
choices=[
|
||||||
'float',
|
'float',
|
||||||
@ -689,7 +692,7 @@ def lora_tab(
|
|||||||
],
|
],
|
||||||
value='fp16',
|
value='fp16',
|
||||||
)
|
)
|
||||||
num_cpu_threads_per_process_input = gr.Slider(
|
num_cpu_threads_per_process = gr.Slider(
|
||||||
minimum=1,
|
minimum=1,
|
||||||
maximum=os.cpu_count(),
|
maximum=os.cpu_count(),
|
||||||
step=1,
|
step=1,
|
||||||
@ -697,18 +700,18 @@ def lora_tab(
|
|||||||
value=os.cpu_count(),
|
value=os.cpu_count(),
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
seed = gr.Textbox(label='Seed', value=1234)
|
||||||
max_resolution_input = gr.Textbox(
|
max_resolution = gr.Textbox(
|
||||||
label='Max resolution',
|
label='Max resolution',
|
||||||
value='512,512',
|
value='512,512',
|
||||||
placeholder='512,512',
|
placeholder='512,512',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
caption_extention_input = gr.Textbox(
|
caption_extention = gr.Textbox(
|
||||||
label='Caption Extension',
|
label='Caption Extension',
|
||||||
placeholder='(Optional) Extension for caption files. default: .caption',
|
placeholder='(Optional) Extension for caption files. default: .caption',
|
||||||
)
|
)
|
||||||
stop_text_encoder_training_input = gr.Slider(
|
stop_text_encoder_training = gr.Slider(
|
||||||
minimum=0,
|
minimum=0,
|
||||||
maximum=100,
|
maximum=100,
|
||||||
value=0,
|
value=0,
|
||||||
@ -716,20 +719,20 @@ def lora_tab(
|
|||||||
label='Stop text encoder training',
|
label='Stop text encoder training',
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_bucket_input = gr.Checkbox(
|
enable_bucket = gr.Checkbox(
|
||||||
label='Enable buckets', value=True
|
label='Enable buckets', value=True
|
||||||
)
|
)
|
||||||
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
|
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||||
use_8bit_adam_input = gr.Checkbox(
|
use_8bit_adam = gr.Checkbox(
|
||||||
label='Use 8bit adam', value=True
|
label='Use 8bit adam', value=True
|
||||||
)
|
)
|
||||||
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||||
with gr.Accordion('Advanced Configuration', open=False):
|
with gr.Accordion('Advanced Configuration', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
full_fp16_input = gr.Checkbox(
|
full_fp16 = gr.Checkbox(
|
||||||
label='Full fp16 training (experimental)', value=False
|
label='Full fp16 training (experimental)', value=False
|
||||||
)
|
)
|
||||||
no_token_padding_input = gr.Checkbox(
|
no_token_padding = gr.Checkbox(
|
||||||
label='No token padding', value=False
|
label='No token padding', value=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -754,7 +757,7 @@ def lora_tab(
|
|||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
inputs=[color_aug],
|
inputs=[color_aug],
|
||||||
outputs=[cache_latent_input],
|
outputs=[cache_latent],
|
||||||
)
|
)
|
||||||
clip_skip = gr.Slider(
|
clip_skip = gr.Slider(
|
||||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||||
@ -783,10 +786,10 @@ def lora_tab(
|
|||||||
'This section provide Dreambooth tools to help setup your dataset...'
|
'This section provide Dreambooth tools to help setup your dataset...'
|
||||||
)
|
)
|
||||||
gradio_dreambooth_folder_creation_tab(
|
gradio_dreambooth_folder_creation_tab(
|
||||||
train_data_dir_input=train_data_dir_input,
|
train_data_dir_input=train_data_dir,
|
||||||
reg_data_dir_input=reg_data_dir_input,
|
reg_data_dir_input=reg_data_dir,
|
||||||
output_dir_input=output_dir_input,
|
output_dir_input=output_dir,
|
||||||
logging_dir_input=logging_dir_input,
|
logging_dir_input=logging_dir,
|
||||||
)
|
)
|
||||||
gradio_dataset_balancing_tab()
|
gradio_dataset_balancing_tab()
|
||||||
gradio_merge_lora_tab()
|
gradio_merge_lora_tab()
|
||||||
@ -794,32 +797,32 @@ def lora_tab(
|
|||||||
button_run = gr.Button('Train model')
|
button_run = gr.Button('Train model')
|
||||||
|
|
||||||
settings_list = [
|
settings_list = [
|
||||||
pretrained_model_name_or_path_input,
|
pretrained_model_name_or_path,
|
||||||
v2_input,
|
v2,
|
||||||
v_parameterization_input,
|
v_parameterization,
|
||||||
logging_dir_input,
|
logging_dir,
|
||||||
train_data_dir_input,
|
train_data_dir,
|
||||||
reg_data_dir_input,
|
reg_data_dir,
|
||||||
output_dir_input,
|
output_dir,
|
||||||
max_resolution_input,
|
max_resolution,
|
||||||
lr_scheduler_input,
|
lr_scheduler,
|
||||||
lr_warmup_input,
|
lr_warmup,
|
||||||
train_batch_size_input,
|
train_batch_size,
|
||||||
epoch_input,
|
epoch,
|
||||||
save_every_n_epochs_input,
|
save_every_n_epochs,
|
||||||
mixed_precision_input,
|
mixed_precision,
|
||||||
save_precision_input,
|
save_precision,
|
||||||
seed_input,
|
seed,
|
||||||
num_cpu_threads_per_process_input,
|
num_cpu_threads_per_process,
|
||||||
cache_latent_input,
|
cache_latent,
|
||||||
caption_extention_input,
|
caption_extention,
|
||||||
enable_bucket_input,
|
enable_bucket,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
full_fp16_input,
|
full_fp16,
|
||||||
no_token_padding_input,
|
no_token_padding,
|
||||||
stop_text_encoder_training_input,
|
stop_text_encoder_training,
|
||||||
use_8bit_adam_input,
|
use_8bit_adam,
|
||||||
xformers_input,
|
xformers,
|
||||||
save_model_as_dropdown,
|
save_model_as_dropdown,
|
||||||
shuffle_caption,
|
shuffle_caption,
|
||||||
save_state,
|
save_state,
|
||||||
@ -835,6 +838,7 @@ def lora_tab(
|
|||||||
gradient_accumulation_steps,
|
gradient_accumulation_steps,
|
||||||
mem_eff_attn,
|
mem_eff_attn,
|
||||||
output_name,
|
output_name,
|
||||||
|
model_list,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
@ -861,10 +865,10 @@ def lora_tab(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
train_data_dir_input,
|
train_data_dir,
|
||||||
reg_data_dir_input,
|
reg_data_dir,
|
||||||
output_dir_input,
|
output_dir,
|
||||||
logging_dir_input,
|
logging_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user