Fix issue with lora model loading
This commit is contained in:
parent
dc5afbb057
commit
11fbc63440
180
finetune/prepare_buckets_latents_new.py
Normal file
180
finetune/prepare_buckets_latents_new.py
Normal file
@ -0,0 +1,180 @@
|
||||
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||
# (c) 2022 Kohya S. @kohya_ss
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKL
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
import library.model_util as model_util
|
||||
|
||||
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing applied to every image before it is handed to the VAE:
# conversion to a tensor followed by normalization with mean 0.5 and std 0.5.
IMAGE_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
|
||||
|
||||
|
||||
def get_latents(vae, images, weight_dtype):
    """Encode a batch of images with the VAE and return the latents as a NumPy array.

    Every image goes through IMAGE_TRANSFORMS, the results are stacked into a
    single batch that is moved to DEVICE with dtype ``weight_dtype``, and one
    sample is drawn from the latent distribution returned by ``vae.encode``.
    The sample is cast to float32 and copied to the CPU before conversion.
    """
    batch = torch.stack([IMAGE_TRANSFORMS(picture) for picture in images])
    batch = batch.to(DEVICE, weight_dtype)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        sampled = vae.encode(batch).latent_dist.sample()
        latents = sampled.float().to("cpu").numpy()

    return latents
|
||||
|
||||
|
||||
def main(args):
    """Assign every image in the input metadata to an aspect-ratio bucket and cache its VAE latent.

    The keys of the JSON object in ``args.in_json`` are used directly as image
    paths. Each image is resized and center-cropped to the resolution of the
    bucket whose aspect ratio is closest to its own, encoded in batches of up
    to ``args.batch_size`` images per bucket, and its latent is written next
    to the image as ``<image name>.npz`` (plus ``<image name>_flip.npz`` for
    the horizontally mirrored image when ``args.flip_aug`` is set). The chosen
    resolution is stored under ``train_resolution`` for each key, and the
    metadata is finally written to ``args.out_json``.

    Returns early, without writing anything, when ``args.in_json`` does not
    exist.

    NOTE(review): ``args.train_data_dir``, ``args.full_path`` and ``args.v2``
    are accepted by the command line but never read here.
    NOTE(review): the nesting of this function was reconstructed from a copy
    of the file that had lost its indentation - confirm against the original.
    """
    if not os.path.exists(args.in_json):
        print(f"no metadata / メタデータファイルがありません: {args.in_json}")
        return

    print(f"loading existing metadata: {args.in_json}")
    with open(args.in_json, "rt", encoding='utf-8') as f:
        metadata = json.load(f)

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
    vae.eval()
    vae.to(DEVICE, dtype=weight_dtype)

    # Compute the bucket resolutions.
    max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
    assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

    bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
        max_reso, args.min_bucket_reso, args.max_bucket_reso)

    # Assign the images to their best-matching bucket one at a time, encoding
    # latents whenever a bucket has collected a full batch.
    bucket_aspect_ratios = np.array(bucket_aspect_ratios)
    buckets_imgs = [[] for _ in range(len(bucket_resos))]
    bucket_counts = [0 for _ in range(len(bucket_resos))]
    img_ar_errors = []

    # Which latent files to write per batch: (file name suffix, mirror horizontally?).
    variants = [('', False)]
    if args.flip_aug:
        variants.append(('_flip', True))

    for i, image_path in enumerate(tqdm(metadata, smoothing=0.0)):
        # The metadata key is the image path itself, so the entry always exists.
        image_key = image_path

        # Load the pixels inside a context manager so the file handle is
        # released right away instead of lingering until garbage collection.
        with Image.open(image_path) as source_image:
            rgb_image = source_image if source_image.mode == 'RGB' else source_image.convert("RGB")
            width, height = rgb_image.size
            image = np.array(rgb_image)

        aspect_ratio = width / height
        ar_errors = bucket_aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        reso = bucket_resos[bucket_id]
        ar_error = ar_errors[bucket_id]
        img_ar_errors.append(abs(ar_error))

        # Pick the scale so that the image covers the bucket and only needs cropping.
        if ar_error <= 0:
            # Image is wider than the bucket: match the height, crop the width.
            scale = reso[1] / height
        else:
            scale = reso[0] / width

        resized_size = (int(width * scale + .5), int(height * scale + .5))

        assert resized_size[0] == reso[0] or resized_size[1] == reso[
            1], f"internal error, resized size not match: {reso}, {resized_size}, {width}, {height}"
        assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
            1], f"internal error, resized size too small: {reso}, {resized_size}, {width}, {height}"

        # Resize, then center-crop. PIL has no INTER_AREA filter, hence OpenCV.
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
        if resized_size[0] > reso[0]:
            trim_size = resized_size[0] - reso[0]
            image = image[:, trim_size//2:trim_size//2 + reso[0]]
        elif resized_size[1] > reso[1]:
            trim_size = resized_size[1] - reso[1]
            image = image[trim_size//2:trim_size//2 + reso[1]]
        assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"

        # Queue the image in its bucket.
        buckets_imgs[bucket_id].append((image_key, reso, image))
        bucket_counts[bucket_id] += 1
        metadata[image_key]['train_resolution'] = reso

        # Encode every bucket that is full; on the last image flush all non-empty buckets.
        is_last = i == len(metadata) - 1
        for bucket in buckets_imgs:
            if not ((is_last and len(bucket) > 0) or len(bucket) >= args.batch_size):
                continue

            for suffix, mirrored in variants:
                # The copy is required: a negatively-strided view cannot be converted to a tensor.
                batch_images = [img[:, ::-1].copy() if mirrored else img for _, _, img in bucket]
                latents = get_latents(vae, batch_images, weight_dtype)

                for (key, _, _), latent in zip(bucket, latents):
                    npz_file_name = os.path.splitext(os.path.basename(key))[0]
                    np.savez(os.path.join(os.path.dirname(key), npz_file_name + suffix), latent)

            bucket.clear()

    for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
        print(f"bucket {i} {reso}: {count}")
    if img_ar_errors:
        # Skipped for empty metadata, where the mean would be NaN with a warning.
        print(f"mean ar error: {np.mean(np.array(img_ar_errors))}")

    # Write the updated metadata and finish.
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
    parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
    parser.add_argument("--v2", action='store_true',
                        help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument("--max_resolution", type=str, default="512,512",
                        help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    # The Japanese half of this help text used to say "minimum" (最小), copied from the option above.
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
    parser.add_argument("--full_path", action="store_true",
                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
    parser.add_argument("--flip_aug", action="store_true",
                        help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")

    args = parser.parse_args()
    main(args)
|
789
finetune_gui copy.py
Normal file
789
finetune_gui copy.py
Normal file
@ -0,0 +1,789 @@
|
||||
import gradio as gr
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import pathlib
|
||||
import shutil
|
||||
import argparse
|
||||
from library.common_gui import (
|
||||
get_folder_path,
|
||||
get_file_path,
|
||||
get_any_file_path,
|
||||
get_saveasfile_path,
|
||||
)
|
||||
from library.utilities import utilities_tab
|
||||
|
||||
# Emoji glyphs available as labels for the GUI's icon buttons.
folder_symbol = '\U0001f4c2'  # 📂
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
document_symbol = '\U0001F4C4'  # 📄
|
||||
|
||||
|
||||
def save_configuration(
    save_as,
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    train_dir,
    image_folder,
    output_dir,
    logging_dir,
    max_resolution,
    min_bucket_reso,
    max_bucket_reso,
    batch_size,
    flip_aug,
    caption_metadata_filename,
    latent_metadata_filename,
    full_path,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    dataset_repeats,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    train_text_encoder,
    create_caption,
    create_buckets,
    save_model_as,
    caption_extension,
    use_8bit_adam,
    xformers,
    clip_skip,
):
    """Write the current GUI settings to a JSON configuration file.

    Args:
        save_as: mapping whose ``'label'`` entry is the string ``'True'`` when
            the user asked for "Save as..."; any other value means plain "Save".
        file_path: path of the currently loaded configuration file, possibly
            ``None`` or empty.
        (remaining parameters): the setting values, stored in the JSON object
            under their own parameter names.

    Returns:
        The path that was written, or the unchanged ``file_path`` argument
        when no target path was chosen.
    """
    original_file_path = file_path

    save_as_bool = save_as.get('label') == 'True'

    if save_as_bool:
        print('Save as...')
        file_path = get_saveasfile_path(file_path)
    else:
        print('Save...')
        # Plain "Save" only asks for a location when none is known yet.
        if not file_path:
            file_path = get_saveasfile_path(file_path)

    # No usable target (None or an empty string, e.g. from a cancelled
    # dialog): leave everything untouched instead of failing in open('').
    if not file_path:
        return original_file_path

    # Collect the settings under the names used in the configuration file.
    variables = {
        'pretrained_model_name_or_path': pretrained_model_name_or_path,
        'v2': v2,
        'v_parameterization': v_parameterization,
        'train_dir': train_dir,
        'image_folder': image_folder,
        'output_dir': output_dir,
        'logging_dir': logging_dir,
        'max_resolution': max_resolution,
        'min_bucket_reso': min_bucket_reso,
        'max_bucket_reso': max_bucket_reso,
        'batch_size': batch_size,
        'flip_aug': flip_aug,
        'caption_metadata_filename': caption_metadata_filename,
        'latent_metadata_filename': latent_metadata_filename,
        'full_path': full_path,
        'learning_rate': learning_rate,
        'lr_scheduler': lr_scheduler,
        'lr_warmup': lr_warmup,
        'dataset_repeats': dataset_repeats,
        'train_batch_size': train_batch_size,
        'epoch': epoch,
        'save_every_n_epochs': save_every_n_epochs,
        'mixed_precision': mixed_precision,
        'save_precision': save_precision,
        'seed': seed,
        'num_cpu_threads_per_process': num_cpu_threads_per_process,
        'train_text_encoder': train_text_encoder,
        'create_buckets': create_buckets,
        'create_caption': create_caption,
        'save_model_as': save_model_as,
        'caption_extension': caption_extension,
        'use_8bit_adam': use_8bit_adam,
        'xformers': xformers,
        'clip_skip': clip_skip,
    }

    # Save the data to the selected file.
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(variables, file)

    return file_path
|
||||
|
||||
|
||||
def open_config_file(
    file_path,
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    train_dir,
    image_folder,
    output_dir,
    logging_dir,
    max_resolution,
    min_bucket_reso,
    max_bucket_reso,
    batch_size,
    flip_aug,
    caption_metadata_filename,
    latent_metadata_filename,
    full_path,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    dataset_repeats,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    train_text_encoder,
    create_caption,
    create_buckets,
    save_model_as,
    caption_extension,
    use_8bit_adam,
    xformers,
    clip_skip,
):
    """Ask for a JSON configuration file and return the settings it contains.

    Args:
        file_path: path of the currently loaded configuration file; it is
            passed to ``get_file_path`` and kept when no file is chosen.
        (remaining parameters): the current setting values, used as the
            fallback for every key missing from the loaded file.

    Returns:
        A tuple ``(file_path, <settings...>)`` whose settings follow the order
        of this function's own parameters. That is also the order of
        ``settings_list`` in ``finetune_tab``, which is wired as the outputs of
        this callback. The previous return order yielded ``create_buckets``
        before ``create_caption`` and therefore swapped the two checkboxes.
    """
    original_file_path = file_path
    file_path = get_file_path(file_path)

    if file_path:
        print(file_path)
        # Load the stored settings from the JSON file.
        with open(file_path, 'r', encoding='utf-8') as f:
            my_data = json.load(f)
    else:
        # In case a file_path was provided and the user decided to cancel the open action.
        file_path = original_file_path
        my_data = {}

    # Insertion order defines the output order; keep it aligned with the signature.
    current_values = {
        'pretrained_model_name_or_path': pretrained_model_name_or_path,
        'v2': v2,
        'v_parameterization': v_parameterization,
        'train_dir': train_dir,
        'image_folder': image_folder,
        'output_dir': output_dir,
        'logging_dir': logging_dir,
        'max_resolution': max_resolution,
        'min_bucket_reso': min_bucket_reso,
        'max_bucket_reso': max_bucket_reso,
        'batch_size': batch_size,
        'flip_aug': flip_aug,
        'caption_metadata_filename': caption_metadata_filename,
        'latent_metadata_filename': latent_metadata_filename,
        'full_path': full_path,
        'learning_rate': learning_rate,
        'lr_scheduler': lr_scheduler,
        'lr_warmup': lr_warmup,
        'dataset_repeats': dataset_repeats,
        'train_batch_size': train_batch_size,
        'epoch': epoch,
        'save_every_n_epochs': save_every_n_epochs,
        'mixed_precision': mixed_precision,
        'save_precision': save_precision,
        'seed': seed,
        'num_cpu_threads_per_process': num_cpu_threads_per_process,
        'train_text_encoder': train_text_encoder,
        'create_caption': create_caption,
        'create_buckets': create_buckets,
        'save_model_as': save_model_as,
        'caption_extension': caption_extension,
        'use_8bit_adam': use_8bit_adam,
        'xformers': xformers,
        'clip_skip': clip_skip,
    }

    return (
        file_path,
        *(my_data.get(key, current) for key, current in current_values.items()),
    )
|
||||
|
||||
|
||||
def train_model(
    pretrained_model_name_or_path,
    v2,
    v_parameterization,
    train_dir,
    image_folder,
    output_dir,
    logging_dir,
    max_resolution,
    min_bucket_reso,
    max_bucket_reso,
    batch_size,
    flip_aug,
    caption_metadata_filename,
    latent_metadata_filename,
    full_path,
    learning_rate,
    lr_scheduler,
    lr_warmup,
    dataset_repeats,
    train_batch_size,
    epoch,
    save_every_n_epochs,
    mixed_precision,
    save_precision,
    seed,
    num_cpu_threads_per_process,
    train_text_encoder,
    generate_caption_database,
    generate_image_buckets,
    save_model_as,
    caption_extension,
    use_8bit_adam,
    xformers,
    clip_skip,
):
    """Prepare the dataset metadata if requested, then launch ``fine_tune.py``.

    Steps, in order:
      1. When ``generate_caption_database`` is set, run
         ``finetune/merge_captions_to_metadata.py`` once for every
         sub-directory found while walking ``image_folder``.
      2. When ``generate_image_buckets`` is set, run
         ``finetune/prepare_buckets_latents.py``.
      3. Count the ``.npz`` files in the sub-directories of ``image_folder``
         to derive ``max_train_steps`` and ``lr_warmup_steps``.
      4. Run ``accelerate launch ... fine_tune.py`` with the given settings.
      5. Unless ``<output_dir>/last`` is a directory, copy the matching v2
         inference YAML next to the saved model.

    NOTE(review): the helper scripts are started through
    ``./venv/Scripts/python.exe`` and every command is passed to
    ``subprocess.run`` as a single string, so this looks Windows-only.
    NOTE(review): the bucket script receives the literal placeholder "crap"
    as its first positional argument - confirm this is intended.
    NOTE(review): the nesting of this function was reconstructed from a copy
    of the file that had lost its indentation - confirm against the original.
    """
    def save_inference_file(output_dir, v2, v_parameterization):
        # Copy the inference config that matches the v2 flavour, if any.
        if v2 and v_parameterization:
            print(f'Saving v2-inference-v.yaml as {output_dir}/last.yaml')
            shutil.copy(
                './v2_inference/v2-inference-v.yaml',
                f'{output_dir}/last.yaml',
            )
        elif v2:
            print(f'Saving v2-inference.yaml as {output_dir}/last.yaml')
            shutil.copy(
                './v2_inference/v2-inference.yaml',
                f'{output_dir}/last.yaml',
            )

    # Create the caption metadata JSON file.
    if generate_caption_database:
        # makedirs also creates missing parent folders, which os.mkdir would not.
        os.makedirs(train_dir, exist_ok=True)

        for root, dirs, _files in os.walk(image_folder):
            for sub_dir in dirs:
                sub_dir_path = os.path.join(root, sub_dir)
                print(sub_dir_path)

                run_cmd = './venv/Scripts/python.exe finetune/merge_captions_to_metadata.py'
                if caption_extension == '':
                    run_cmd += ' --caption_extension=".txt"'
                else:
                    run_cmd += f' --caption_extension={caption_extension}'
                run_cmd += f' "{sub_dir_path}"'
                run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
                if full_path:
                    run_cmd += ' --full_path'

                print(run_cmd)

                # Run the command
                subprocess.run(run_cmd)

    # Create the image buckets and cached latents.
    if generate_image_buckets:
        run_cmd = './venv/Scripts/python.exe finetune/prepare_buckets_latents.py'
        run_cmd += ' "crap"'
        run_cmd += f' "{train_dir}/{caption_metadata_filename}"'
        run_cmd += f' "{train_dir}/{latent_metadata_filename}"'
        run_cmd += f' "{pretrained_model_name_or_path}"'
        run_cmd += f' --batch_size={batch_size}'
        run_cmd += f' --max_resolution={max_resolution}'
        run_cmd += f' --min_bucket_reso={min_bucket_reso}'
        run_cmd += f' --max_bucket_reso={max_bucket_reso}'
        run_cmd += f' --mixed_precision={mixed_precision}'
        if flip_aug:
            run_cmd += ' --flip_aug'
        if full_path:
            run_cmd += ' --full_path'

        print(run_cmd)

        # Run the command
        subprocess.run(run_cmd)

    # Count the cached latents (.npz) in every sub-directory of image_folder.
    image_num = 0
    for root, dirs, _files in os.walk(image_folder):
        for sub_dir in dirs:
            image_num += sum(
                1
                for name in os.listdir(os.path.join(root, sub_dir))
                if name.endswith('.npz')
            )
    print(f'image_num = {image_num}')

    repeats = image_num * int(dataset_repeats)
    print(f'repeats = {str(repeats)}')

    # calculate max_train_steps
    max_train_steps = int(
        math.ceil(float(repeats) / int(train_batch_size) * int(epoch))
    )

    # Divide by two because flip augmentation creates two copies of the source images
    if flip_aug:
        max_train_steps = int(math.ceil(float(max_train_steps) / 2))

    print(f'max_train_steps = {max_train_steps}')

    # lr_warmup is a percentage of the total number of steps.
    lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
    print(f'lr_warmup_steps = {lr_warmup_steps}')

    run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
    if v2:
        run_cmd += ' --v2'
    if v_parameterization:
        run_cmd += ' --v_parameterization'
    if train_text_encoder:
        run_cmd += ' --train_text_encoder'
    if use_8bit_adam:
        run_cmd += ' --use_8bit_adam'
    if xformers:
        run_cmd += ' --xformers'
    run_cmd += (
        f' --pretrained_model_name_or_path="{pretrained_model_name_or_path}"'
    )
    run_cmd += f' --in_json="{train_dir}/{latent_metadata_filename}"'
    run_cmd += f' --train_data_dir="{image_folder}"'
    run_cmd += f' --output_dir="{output_dir}"'
    if not logging_dir == '':
        run_cmd += f' --logging_dir="{logging_dir}"'
    run_cmd += f' --train_batch_size={train_batch_size}'
    run_cmd += f' --dataset_repeats={dataset_repeats}'
    run_cmd += f' --learning_rate={learning_rate}'
    run_cmd += f' --lr_scheduler={lr_scheduler}'
    run_cmd += f' --lr_warmup_steps={lr_warmup_steps}'
    run_cmd += f' --max_train_steps={max_train_steps}'
    run_cmd += f' --mixed_precision={mixed_precision}'
    run_cmd += f' --save_every_n_epochs={save_every_n_epochs}'
    run_cmd += f' --seed={seed}'
    run_cmd += f' --save_precision={save_precision}'
    if not save_model_as == 'same as source model':
        run_cmd += f' --save_model_as={save_model_as}'
    if int(clip_skip) > 1:
        run_cmd += f' --clip_skip={str(clip_skip)}'

    print(run_cmd)
    # Run the command
    subprocess.run(run_cmd)

    # If output_dir/last is a folder the result is a diffusers model, which
    # does not need the separate inference YAML.
    last_dir = pathlib.Path(f'{output_dir}/last')

    if not last_dir.is_dir():
        # Copy inference model for v2 if required
        save_inference_file(output_dir, v2, v_parameterization)
|
||||
|
||||
|
||||
def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
    """Derive the ``v2`` / ``v_parameterization`` flags from a quick-pick model name.

    Args:
        value: the selected model name, or ``'custom'``.
        v2: current state of the v2 flag.
        v_parameterization: current state of the v_parameterization flag.

    Returns:
        A ``(value, v2, v_parameterization)`` tuple. Known model names (exact
        match on ``str(value)``) force the flags to the values that model
        needs; ``'custom'`` clears the name and both flags; any other value
        is returned together with the flags that were passed in.
    """
    model_name = str(value)

    # Stable Diffusion 2.x "base" models: v2 without v-parameterization.
    if model_name in (
        'stabilityai/stable-diffusion-2-1-base',
        'stabilityai/stable-diffusion-2-base',
    ):
        print('SD v2 model detected. Setting --v2 parameter')
        return value, True, False

    # Stable Diffusion 2.x v-objective models.
    if model_name in (
        'stabilityai/stable-diffusion-2-1',
        'stabilityai/stable-diffusion-2',
    ):
        print(
            'SD v2 v_parameterization detected. Setting --v2 parameter and --v_parameterization'
        )
        return value, True, True

    # Stable Diffusion 1.x models.
    if model_name in (
        'CompVis/stable-diffusion-v1-4',
        'runwayml/stable-diffusion-v1-5',
    ):
        return value, False, False

    if value == 'custom':
        return '', False, False

    # Unknown value: always hand back a full tuple so a three-output
    # callback never receives None.
    return value, v2, v_parameterization
|
||||
|
||||
|
||||
def remove_doublequote(file_path):
    """Return ``file_path`` without any double-quote characters; ``None`` is passed through."""
    if file_path is None:
        return None
    return file_path.replace('"', '')
|
||||
|
||||
|
||||
def UI(username, password):
    """Build the Gradio interface (Finetune and Utilities tabs) and launch it.

    The contents of ``./style.css`` are used as custom CSS when that file
    exists. Authentication with ``(username, password)`` is enabled whenever
    ``username`` is not the empty string.
    """
    css = ''
    css_path = './style.css'

    if os.path.exists(css_path):
        with open(css_path, 'r', encoding='utf8') as file:
            print('Load CSS...')
            css += file.read() + '\n'

    interface = gr.Blocks(css=css)

    with interface:
        with gr.Tab('Finetune'):
            finetune_tab()
        with gr.Tab('Utilities'):
            utilities_tab(enable_dreambooth_tab=False)

    # Show the interface, protected by credentials when a username was given.
    launch_options = {} if username == '' else {'auth': (username, password)}
    interface.launch(**launch_options)
|
||||
|
||||
|
||||
def finetune_tab():
    """Create the widgets of the 'Finetune' tab and wire up their callbacks.

    The tab consists of a configuration-file accordion, four sub-tabs
    ('Source model', 'Folders', 'Dataset preparation', 'Training parameters'),
    two metadata-generation checkboxes and the 'Train model' button. The
    widgets collected in ``settings_list`` are passed, in that order, to
    ``train_model``, ``open_config_file`` and ``save_configuration``.

    NOTE(review): the container nesting (rows / accordions / tabs) was
    reconstructed from a copy of the file that had lost its indentation -
    confirm the layout against the original.
    """
    # Hidden labels that tell save_configuration which button was pressed.
    dummy_ft_true = gr.Label(value=True, visible=False)
    dummy_ft_false = gr.Label(value=False, visible=False)
    gr.Markdown('Train a custom model using kohya finetune python code...')

    quick_pick_models = [
        'custom',
        'stabilityai/stable-diffusion-2-1-base',
        'stabilityai/stable-diffusion-2-base',
        'stabilityai/stable-diffusion-2-1',
        'stabilityai/stable-diffusion-2',
        'runwayml/stable-diffusion-v1-5',
        'CompVis/stable-diffusion-v1-4',
    ]
    save_formats = [
        'same as source model',
        'ckpt',
        'diffusers',
        'diffusers_safetensors',
        'safetensors',
    ]
    scheduler_names = [
        'constant',
        'constant_with_warmup',
        'cosine',
        'cosine_with_restarts',
        'linear',
        'polynomial',
    ]

    with gr.Accordion('Configuration file', open=False):
        with gr.Row():
            button_open_config = gr.Button(
                f'Open {folder_symbol}', elem_id='open_folder'
            )
            button_save_config = gr.Button(
                f'Save {save_style_symbol}', elem_id='open_folder'
            )
            button_save_as_config = gr.Button(
                f'Save as... {save_style_symbol}', elem_id='open_folder'
            )
        config_file_name = gr.Textbox(
            label='', placeholder='type file path or use buttons...'
        )
        config_file_name.change(
            remove_doublequote,
            inputs=[config_file_name],
            outputs=[config_file_name],
        )

    with gr.Tab('Source model'):
        # Model selection widgets.
        with gr.Row():
            pretrained_model_name_or_path_input = gr.Textbox(
                label='Pretrained model name or path',
                placeholder='enter the path to custom model or name of pretrained model',
            )
            pretrained_model_name_or_path_file = gr.Button(
                document_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_file.click(
                get_any_file_path,
                inputs=pretrained_model_name_or_path_input,
                outputs=pretrained_model_name_or_path_input,
            )
            pretrained_model_name_or_path_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            pretrained_model_name_or_path_folder.click(
                get_folder_path,
                inputs=pretrained_model_name_or_path_input,
                outputs=pretrained_model_name_or_path_input,
            )
            model_list = gr.Dropdown(
                label='(Optional) Model Quick Pick',
                choices=quick_pick_models,
            )
            save_model_as_dropdown = gr.Dropdown(
                label='Save trained model as',
                choices=save_formats,
                value='same as source model',
            )

        with gr.Row():
            v2_input = gr.Checkbox(label='v2', value=True)
            v_parameterization_input = gr.Checkbox(
                label='v_parameterization', value=False
            )
        # Picking a quick-pick model fills in the name and both flags.
        model_list.change(
            set_pretrained_model_name_or_path_input,
            inputs=[model_list, v2_input, v_parameterization_input],
            outputs=[
                pretrained_model_name_or_path_input,
                v2_input,
                v_parameterization_input,
            ],
        )

    with gr.Tab('Folders'):
        with gr.Row():
            train_dir_input = gr.Textbox(
                label='Training config folder',
                placeholder='folder where the training configuration files will be saved',
            )
            train_dir_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            train_dir_folder.click(get_folder_path, outputs=train_dir_input)

            image_folder_input = gr.Textbox(
                label='Training Image folder',
                placeholder='folder where the training images are located',
            )
            image_folder_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            image_folder_input_folder.click(
                get_folder_path, outputs=image_folder_input
            )
        with gr.Row():
            output_dir_input = gr.Textbox(
                label='Output folder',
                placeholder='folder where the model will be saved',
            )
            output_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            output_dir_input_folder.click(
                get_folder_path, outputs=output_dir_input
            )

            logging_dir_input = gr.Textbox(
                label='Logging folder',
                placeholder='Optional: enable logging and output TensorBoard log to this folder',
            )
            logging_dir_input_folder = gr.Button(
                folder_symbol, elem_id='open_folder_small'
            )
            logging_dir_input_folder.click(
                get_folder_path, outputs=logging_dir_input
            )
        # Strip double quotes from these three folder boxes on every change.
        for folder_box in (train_dir_input, image_folder_input, output_dir_input):
            folder_box.change(
                remove_doublequote,
                inputs=[folder_box],
                outputs=[folder_box],
            )

    with gr.Tab('Dataset preparation'):
        with gr.Row():
            max_resolution_input = gr.Textbox(
                label='Resolution (width,height)', value='512,512'
            )
            min_bucket_reso = gr.Textbox(
                label='Min bucket resolution', value='256'
            )
            max_bucket_reso = gr.Textbox(
                label='Max bucket resolution', value='1024'
            )
            batch_size = gr.Textbox(label='Batch size', value='1')
        with gr.Accordion('Advanced parameters', open=False):
            with gr.Row():
                caption_metadata_filename = gr.Textbox(
                    label='Caption metadata filename', value='meta_cap.json'
                )
                latent_metadata_filename = gr.Textbox(
                    label='Latent metadata filename', value='meta_lat.json'
                )
                full_path = gr.Checkbox(label='Use full path', value=True)
                flip_aug = gr.Checkbox(label='Flip augmentation', value=False)

    with gr.Tab('Training parameters'):
        with gr.Row():
            learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
            lr_scheduler_input = gr.Dropdown(
                label='LR Scheduler',
                choices=scheduler_names,
                value='constant',
            )
            lr_warmup_input = gr.Textbox(label='LR warmup', value=0)
        with gr.Row():
            dataset_repeats_input = gr.Textbox(
                label='Dataset repeats', value=40
            )
            train_batch_size_input = gr.Slider(
                minimum=1,
                maximum=32,
                label='Train batch size',
                value=1,
                step=1,
            )
            epoch_input = gr.Textbox(label='Epoch', value=1)
            save_every_n_epochs_input = gr.Textbox(
                label='Save every N epochs', value=1
            )
        with gr.Row():
            mixed_precision_input = gr.Dropdown(
                label='Mixed precision',
                choices=['no', 'fp16', 'bf16'],
                value='fp16',
            )
            save_precision_input = gr.Dropdown(
                label='Save precision',
                choices=['float', 'fp16', 'bf16'],
                value='fp16',
            )
            num_cpu_threads_per_process_input = gr.Slider(
                minimum=1,
                maximum=os.cpu_count(),
                step=1,
                label='Number of CPU threads per process',
                value=os.cpu_count(),
            )
            seed_input = gr.Textbox(label='Seed', value=1234)
        with gr.Row():
            caption_extention_input = gr.Textbox(
                label='Caption Extension',
                placeholder='(Optional) Extension for caption files. default: .txt',
            )
            train_text_encoder_input = gr.Checkbox(
                label='Train text encoder', value=True
            )
        with gr.Accordion('Advanced parameters', open=False):
            with gr.Row():
                use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True)
                xformers = gr.Checkbox(label='Use xformers', value=True)
                clip_skip = gr.Slider(
                    label='Clip skip', value='1', minimum=1, maximum=12, step=1
                )

    with gr.Box():
        with gr.Row():
            create_caption = gr.Checkbox(
                label='Generate caption metadata', value=True
            )
            create_buckets = gr.Checkbox(
                label='Generate image buckets metadata', value=True
            )

    button_run = gr.Button('Train model')

    # Order matters: it must match the positional parameters of train_model,
    # open_config_file and save_configuration.
    settings_list = [
        pretrained_model_name_or_path_input,
        v2_input,
        v_parameterization_input,
        train_dir_input,
        image_folder_input,
        output_dir_input,
        logging_dir_input,
        max_resolution_input,
        min_bucket_reso,
        max_bucket_reso,
        batch_size,
        flip_aug,
        caption_metadata_filename,
        latent_metadata_filename,
        full_path,
        learning_rate_input,
        lr_scheduler_input,
        lr_warmup_input,
        dataset_repeats_input,
        train_batch_size_input,
        epoch_input,
        save_every_n_epochs_input,
        mixed_precision_input,
        save_precision_input,
        seed_input,
        num_cpu_threads_per_process_input,
        train_text_encoder_input,
        create_caption,
        create_buckets,
        save_model_as_dropdown,
        caption_extention_input,
        use_8bit_adam,
        xformers,
        clip_skip,
    ]

    button_run.click(train_model, inputs=settings_list)

    button_open_config.click(
        open_config_file,
        inputs=[config_file_name, *settings_list],
        outputs=[config_file_name, *settings_list],
    )

    # "Save" and "Save as..." share one callback; the hidden label picks the mode.
    for save_button, save_as_flag in (
        (button_save_config, dummy_ft_false),
        (button_save_as_config, dummy_ft_true),
    ):
        save_button.click(
            save_configuration,
            inputs=[save_as_flag, config_file_name, *settings_list],
            outputs=[config_file_name],
        )
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: optional credentials protect the web UI.
    parser = argparse.ArgumentParser()
    for option, description in (
        ('--username', 'Username for authentication'),
        ('--password', 'Password for authentication'),
    ):
        parser.add_argument(option, type=str, default='', help=description)

    args = parser.parse_args()

    UI(username=args.username, password=args.password)
|
250
lora_gui.py
250
lora_gui.py
@ -79,6 +79,7 @@ def save_configuration(
|
||||
gradient_accumulation_steps,
|
||||
mem_eff_attn,
|
||||
output_name,
|
||||
model_list,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -120,32 +121,32 @@ def save_configuration(
|
||||
|
||||
def open_configuration(
|
||||
file_path,
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
logging_dir_input,
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
max_resolution_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
cache_latent_input,
|
||||
caption_extention_input,
|
||||
enable_bucket_input,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
max_resolution,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
mixed_precision,
|
||||
save_precision,
|
||||
seed,
|
||||
num_cpu_threads_per_process,
|
||||
cache_latent,
|
||||
caption_extention,
|
||||
enable_bucket,
|
||||
gradient_checkpointing,
|
||||
full_fp16_input,
|
||||
no_token_padding_input,
|
||||
stop_text_encoder_training_input,
|
||||
use_8bit_adam_input,
|
||||
xformers_input,
|
||||
full_fp16,
|
||||
no_token_padding,
|
||||
stop_text_encoder_training,
|
||||
use_8bit_adam,
|
||||
xformers,
|
||||
save_model_as_dropdown,
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
@ -161,27 +162,28 @@ def open_configuration(
|
||||
gradient_accumulation_steps,
|
||||
mem_eff_attn,
|
||||
output_name,
|
||||
model_list,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
|
||||
|
||||
original_file_path = file_path
|
||||
file_path = get_file_path(file_path)
|
||||
|
||||
if not file_path == '' and not file_path == None:
|
||||
# load variables from JSON file
|
||||
with open(file_path, 'r') as f:
|
||||
my_data_lora = json.load(f)
|
||||
my_data = json.load(f)
|
||||
print("Loading config...")
|
||||
else:
|
||||
file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action
|
||||
my_data_lora = {}
|
||||
|
||||
my_data = {}
|
||||
|
||||
values = [file_path]
|
||||
for key, value in parameters:
|
||||
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
|
||||
if not key in ['file_path']:
|
||||
values.append(my_data_lora.get(key, value))
|
||||
values.append(my_data.get(key, value))
|
||||
return tuple(values)
|
||||
|
||||
|
||||
@ -227,6 +229,7 @@ def train_model(
|
||||
gradient_accumulation_steps,
|
||||
mem_eff_attn,
|
||||
output_name,
|
||||
model_list,
|
||||
):
|
||||
if pretrained_model_name_or_path == '':
|
||||
msgbox('Source model information is missing')
|
||||
@ -480,7 +483,7 @@ def lora_tab(
|
||||
with gr.Tab('Source model'):
|
||||
# Define the input elements
|
||||
with gr.Row():
|
||||
pretrained_model_name_or_path_input = gr.Textbox(
|
||||
pretrained_model_name_or_path = gr.Textbox(
|
||||
label='Pretrained model name or path',
|
||||
placeholder='enter the path to custom model or name of pretrained model',
|
||||
)
|
||||
@ -489,15 +492,15 @@ def lora_tab(
|
||||
)
|
||||
pretrained_model_name_or_path_file.click(
|
||||
get_any_file_path,
|
||||
inputs=[pretrained_model_name_or_path_input],
|
||||
outputs=pretrained_model_name_or_path_input,
|
||||
inputs=[pretrained_model_name_or_path],
|
||||
outputs=pretrained_model_name_or_path,
|
||||
)
|
||||
pretrained_model_name_or_path_folder = gr.Button(
|
||||
folder_symbol, elem_id='open_folder_small'
|
||||
)
|
||||
pretrained_model_name_or_path_folder.click(
|
||||
get_folder_path,
|
||||
outputs=pretrained_model_name_or_path_input,
|
||||
outputs=pretrained_model_name_or_path,
|
||||
)
|
||||
model_list = gr.Dropdown(
|
||||
label='(Optional) Model Quick Pick',
|
||||
@ -524,67 +527,67 @@ def lora_tab(
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
v2_input = gr.Checkbox(label='v2', value=True)
|
||||
v_parameterization_input = gr.Checkbox(
|
||||
v2 = gr.Checkbox(label='v2', value=True)
|
||||
v_parameterization = gr.Checkbox(
|
||||
label='v_parameterization', value=False
|
||||
)
|
||||
pretrained_model_name_or_path_input.change(
|
||||
pretrained_model_name_or_path.change(
|
||||
remove_doublequote,
|
||||
inputs=[pretrained_model_name_or_path_input],
|
||||
outputs=[pretrained_model_name_or_path_input],
|
||||
inputs=[pretrained_model_name_or_path],
|
||||
outputs=[pretrained_model_name_or_path],
|
||||
)
|
||||
model_list.change(
|
||||
set_pretrained_model_name_or_path_input,
|
||||
inputs=[model_list, v2_input, v_parameterization_input],
|
||||
inputs=[model_list, v2, v_parameterization],
|
||||
outputs=[
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
],
|
||||
)
|
||||
|
||||
with gr.Tab('Folders'):
|
||||
with gr.Row():
|
||||
train_data_dir_input = gr.Textbox(
|
||||
train_data_dir = gr.Textbox(
|
||||
label='Image folder',
|
||||
placeholder='Folder where the training folders containing the images are located',
|
||||
)
|
||||
train_data_dir_input_folder = gr.Button(
|
||||
train_data_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
train_data_dir_input_folder.click(
|
||||
get_folder_path, outputs=train_data_dir_input
|
||||
train_data_dir_folder.click(
|
||||
get_folder_path, outputs=train_data_dir
|
||||
)
|
||||
reg_data_dir_input = gr.Textbox(
|
||||
reg_data_dir = gr.Textbox(
|
||||
label='Regularisation folder',
|
||||
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
|
||||
)
|
||||
reg_data_dir_input_folder = gr.Button(
|
||||
reg_data_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
reg_data_dir_input_folder.click(
|
||||
get_folder_path, outputs=reg_data_dir_input
|
||||
reg_data_dir_folder.click(
|
||||
get_folder_path, outputs=reg_data_dir
|
||||
)
|
||||
with gr.Row():
|
||||
output_dir_input = gr.Textbox(
|
||||
output_dir = gr.Textbox(
|
||||
label='Output folder',
|
||||
placeholder='Folder to output trained model',
|
||||
)
|
||||
output_dir_input_folder = gr.Button(
|
||||
output_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
output_dir_input_folder.click(
|
||||
get_folder_path, outputs=output_dir_input
|
||||
output_dir_folder.click(
|
||||
get_folder_path, outputs=output_dir
|
||||
)
|
||||
logging_dir_input = gr.Textbox(
|
||||
logging_dir = gr.Textbox(
|
||||
label='Logging folder',
|
||||
placeholder='Optional: enable logging and output TensorBoard log to this folder',
|
||||
)
|
||||
logging_dir_input_folder = gr.Button(
|
||||
logging_dir_folder = gr.Button(
|
||||
'📂', elem_id='open_folder_small'
|
||||
)
|
||||
logging_dir_input_folder.click(
|
||||
get_folder_path, outputs=logging_dir_input
|
||||
logging_dir_folder.click(
|
||||
get_folder_path, outputs=logging_dir
|
||||
)
|
||||
with gr.Row():
|
||||
output_name = gr.Textbox(
|
||||
@ -593,25 +596,25 @@ def lora_tab(
|
||||
value='last',
|
||||
interactive=True,
|
||||
)
|
||||
train_data_dir_input.change(
|
||||
train_data_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[train_data_dir_input],
|
||||
outputs=[train_data_dir_input],
|
||||
inputs=[train_data_dir],
|
||||
outputs=[train_data_dir],
|
||||
)
|
||||
reg_data_dir_input.change(
|
||||
reg_data_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[reg_data_dir_input],
|
||||
outputs=[reg_data_dir_input],
|
||||
inputs=[reg_data_dir],
|
||||
outputs=[reg_data_dir],
|
||||
)
|
||||
output_dir_input.change(
|
||||
output_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[output_dir_input],
|
||||
outputs=[output_dir_input],
|
||||
inputs=[output_dir],
|
||||
outputs=[output_dir],
|
||||
)
|
||||
logging_dir_input.change(
|
||||
logging_dir.change(
|
||||
remove_doublequote,
|
||||
inputs=[logging_dir_input],
|
||||
outputs=[logging_dir_input],
|
||||
inputs=[logging_dir],
|
||||
outputs=[logging_dir],
|
||||
)
|
||||
with gr.Tab('Training parameters'):
|
||||
with gr.Row():
|
||||
@ -628,7 +631,7 @@ def lora_tab(
|
||||
outputs=lora_network_weights,
|
||||
)
|
||||
with gr.Row():
|
||||
lr_scheduler_input = gr.Dropdown(
|
||||
lr_scheduler = gr.Dropdown(
|
||||
label='LR Scheduler',
|
||||
choices=[
|
||||
'constant',
|
||||
@ -640,7 +643,7 @@ def lora_tab(
|
||||
],
|
||||
value='cosine',
|
||||
)
|
||||
lr_warmup_input = gr.Textbox(label='LR warmup (% of steps)', value=10)
|
||||
lr_warmup = gr.Textbox(label='LR warmup (% of steps)', value=10)
|
||||
with gr.Row():
|
||||
text_encoder_lr = gr.Textbox(
|
||||
label='Text Encoder learning rate',
|
||||
@ -659,19 +662,19 @@ def lora_tab(
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
train_batch_size_input = gr.Slider(
|
||||
train_batch_size = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=32,
|
||||
label='Train batch size',
|
||||
value=1,
|
||||
step=1,
|
||||
)
|
||||
epoch_input = gr.Textbox(label='Epoch', value=1)
|
||||
save_every_n_epochs_input = gr.Textbox(
|
||||
epoch = gr.Textbox(label='Epoch', value=1)
|
||||
save_every_n_epochs = gr.Textbox(
|
||||
label='Save every N epochs', value=1
|
||||
)
|
||||
with gr.Row():
|
||||
mixed_precision_input = gr.Dropdown(
|
||||
mixed_precision = gr.Dropdown(
|
||||
label='Mixed precision',
|
||||
choices=[
|
||||
'no',
|
||||
@ -680,7 +683,7 @@ def lora_tab(
|
||||
],
|
||||
value='fp16',
|
||||
)
|
||||
save_precision_input = gr.Dropdown(
|
||||
save_precision = gr.Dropdown(
|
||||
label='Save precision',
|
||||
choices=[
|
||||
'float',
|
||||
@ -689,7 +692,7 @@ def lora_tab(
|
||||
],
|
||||
value='fp16',
|
||||
)
|
||||
num_cpu_threads_per_process_input = gr.Slider(
|
||||
num_cpu_threads_per_process = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=os.cpu_count(),
|
||||
step=1,
|
||||
@ -697,18 +700,18 @@ def lora_tab(
|
||||
value=os.cpu_count(),
|
||||
)
|
||||
with gr.Row():
|
||||
seed_input = gr.Textbox(label='Seed', value=1234)
|
||||
max_resolution_input = gr.Textbox(
|
||||
seed = gr.Textbox(label='Seed', value=1234)
|
||||
max_resolution = gr.Textbox(
|
||||
label='Max resolution',
|
||||
value='512,512',
|
||||
placeholder='512,512',
|
||||
)
|
||||
with gr.Row():
|
||||
caption_extention_input = gr.Textbox(
|
||||
caption_extention = gr.Textbox(
|
||||
label='Caption Extension',
|
||||
placeholder='(Optional) Extension for caption files. default: .caption',
|
||||
)
|
||||
stop_text_encoder_training_input = gr.Slider(
|
||||
stop_text_encoder_training = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=100,
|
||||
value=0,
|
||||
@ -716,20 +719,20 @@ def lora_tab(
|
||||
label='Stop text encoder training',
|
||||
)
|
||||
with gr.Row():
|
||||
enable_bucket_input = gr.Checkbox(
|
||||
enable_bucket = gr.Checkbox(
|
||||
label='Enable buckets', value=True
|
||||
)
|
||||
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
|
||||
use_8bit_adam_input = gr.Checkbox(
|
||||
cache_latent = gr.Checkbox(label='Cache latent', value=True)
|
||||
use_8bit_adam = gr.Checkbox(
|
||||
label='Use 8bit adam', value=True
|
||||
)
|
||||
xformers_input = gr.Checkbox(label='Use xformers', value=True)
|
||||
xformers = gr.Checkbox(label='Use xformers', value=True)
|
||||
with gr.Accordion('Advanced Configuration', open=False):
|
||||
with gr.Row():
|
||||
full_fp16_input = gr.Checkbox(
|
||||
full_fp16 = gr.Checkbox(
|
||||
label='Full fp16 training (experimental)', value=False
|
||||
)
|
||||
no_token_padding_input = gr.Checkbox(
|
||||
no_token_padding = gr.Checkbox(
|
||||
label='No token padding', value=False
|
||||
)
|
||||
|
||||
@ -754,7 +757,7 @@ def lora_tab(
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
inputs=[color_aug],
|
||||
outputs=[cache_latent_input],
|
||||
outputs=[cache_latent],
|
||||
)
|
||||
clip_skip = gr.Slider(
|
||||
label='Clip skip', value='1', minimum=1, maximum=12, step=1
|
||||
@ -783,10 +786,10 @@ def lora_tab(
|
||||
'This section provide Dreambooth tools to help setup your dataset...'
|
||||
)
|
||||
gradio_dreambooth_folder_creation_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
train_data_dir_input=train_data_dir,
|
||||
reg_data_dir_input=reg_data_dir,
|
||||
output_dir_input=output_dir,
|
||||
logging_dir_input=logging_dir,
|
||||
)
|
||||
gradio_dataset_balancing_tab()
|
||||
gradio_merge_lora_tab()
|
||||
@ -794,32 +797,32 @@ def lora_tab(
|
||||
button_run = gr.Button('Train model')
|
||||
|
||||
settings_list = [
|
||||
pretrained_model_name_or_path_input,
|
||||
v2_input,
|
||||
v_parameterization_input,
|
||||
logging_dir_input,
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
max_resolution_input,
|
||||
lr_scheduler_input,
|
||||
lr_warmup_input,
|
||||
train_batch_size_input,
|
||||
epoch_input,
|
||||
save_every_n_epochs_input,
|
||||
mixed_precision_input,
|
||||
save_precision_input,
|
||||
seed_input,
|
||||
num_cpu_threads_per_process_input,
|
||||
cache_latent_input,
|
||||
caption_extention_input,
|
||||
enable_bucket_input,
|
||||
pretrained_model_name_or_path,
|
||||
v2,
|
||||
v_parameterization,
|
||||
logging_dir,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
max_resolution,
|
||||
lr_scheduler,
|
||||
lr_warmup,
|
||||
train_batch_size,
|
||||
epoch,
|
||||
save_every_n_epochs,
|
||||
mixed_precision,
|
||||
save_precision,
|
||||
seed,
|
||||
num_cpu_threads_per_process,
|
||||
cache_latent,
|
||||
caption_extention,
|
||||
enable_bucket,
|
||||
gradient_checkpointing,
|
||||
full_fp16_input,
|
||||
no_token_padding_input,
|
||||
stop_text_encoder_training_input,
|
||||
use_8bit_adam_input,
|
||||
xformers_input,
|
||||
full_fp16,
|
||||
no_token_padding,
|
||||
stop_text_encoder_training,
|
||||
use_8bit_adam,
|
||||
xformers,
|
||||
save_model_as_dropdown,
|
||||
shuffle_caption,
|
||||
save_state,
|
||||
@ -835,6 +838,7 @@ def lora_tab(
|
||||
gradient_accumulation_steps,
|
||||
mem_eff_attn,
|
||||
output_name,
|
||||
model_list,
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
@ -861,10 +865,10 @@ def lora_tab(
|
||||
)
|
||||
|
||||
return (
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
train_data_dir,
|
||||
reg_data_dir,
|
||||
output_dir,
|
||||
logging_dir,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user