Merge pull request #147 from bmaltais/dev

v20.7.2
This commit is contained in:
bmaltais 2023-02-11 12:00:17 -05:00 committed by GitHub
commit a008c62893
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1017 additions and 377 deletions

View File

@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。bf1
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r requirements.txt
pip install --use-pep517 --upgrade -r requirements.txt
```
コマンドが成功すれば新しいバージョンが使用できます。

View File

@ -163,6 +163,14 @@ This will store your a backup file with your current locally installed pip packa
## Change History
* 2023/02/11 (v20.7.2):
- ``lora_interrogator.py`` is added in ``networks`` folder. See ``python networks\lora_interrogator.py -h`` for usage.
- For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate.
- Batch size can be large (like 64 or 128).
- ``train_textual_inversion.py`` now supports multiple init words.
- Following feature is reverted to be the same as before. Sorry for confusion:
> Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
- Add new tool to sort, group and average crop image in a dataset
* 2023/02/09 (v20.7.1)
- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
- ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).

View File

@ -435,40 +435,6 @@ def train_model(
save_inference_file(output_dir, v2, v_parameterization, output_name)
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def dreambooth_tab(
train_data_dir=gr.Textbox(),
reg_data_dir=gr.Textbox(),
@ -735,6 +701,44 @@ def dreambooth_tab(
)
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = dreambooth_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
@ -744,7 +748,11 @@ if __name__ == '__main__':
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -431,30 +431,6 @@ def remove_doublequote(file_path):
return file_path
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def finetune_tab():
dummy_ft_true = gr.Label(value=True, visible=False)
dummy_ft_false = gr.Label(value=False, visible=False)
@ -708,6 +684,35 @@ def finetune_tab():
)
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Finetune'):
finetune_tab()
with gr.Tab('Utilities'):
utilities_tab(enable_dreambooth_tab=False)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
@ -717,7 +722,11 @@ if __name__ == '__main__':
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -1,10 +1,6 @@
@echo off
set VENV_DIR=.\venv
set PYTHON=python
call %VENV_DIR%\Scripts\activate.bat
%PYTHON% kohya_gui.py
call venv\Scripts\activate.bat
python.exe kohya_gui.py %*
pause

View File

@ -1,2 +1,2 @@
.\venv\Scripts\activate
python.exe kohya_gui.py
python.exe kohya_gui.py $args

View File

@ -10,8 +10,7 @@ from library.merge_lora_gui import gradio_merge_lora_tab
from lora_gui import lora_tab
def UI(username, password, inbrowser, server_port):
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
@ -47,13 +46,18 @@ def UI(username, password, inbrowser, server_port):
gradio_merge_lora_tab()
# Show the interface
kwargs = {}
if username:
kwargs["auth"] = (username, password)
launch_kwargs = {}
username = kwargs.get('username')
password = kwargs.get('password')
server_port = kwargs.get('server_port', 0)
inbrowser = kwargs.get('inbrowser', False)
if username and password:
launch_kwargs["auth"] = (username, password)
if server_port > 0:
kwargs["server_port"] = server_port
kwargs["inbrowser"] = inbrowser
interface.launch(**kwargs)
launch_kwargs["server_port"] = server_port
if inbrowser:
launch_kwargs["inbrowser"] = inbrowser
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)

View File

@ -568,9 +568,11 @@ def gradio_advanced_training():
label="Dropout caption every n epochs",
value=0
)
caption_dropout_rate = gr.Number(
caption_dropout_rate = gr.Slider(
label="Rate of caption dropout",
value=0
value=0,
minimum=0,
maximum=1
)
with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False)

View File

@ -226,6 +226,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
self.tag_dropout_rate: float = 0
# augmentation
flip_p = 0.5 if flip_aug else 0.0
@ -284,7 +285,7 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
if self.shuffle_caption:
if self.shuffle_caption or self.tag_dropout_rate > 0:
def dropout_tags(tokens):
if self.tag_dropout_rate <= 0:
return tokens
@ -296,13 +297,18 @@ class BaseDataset(torch.utils.data.Dataset):
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
random.shuffle(tokens)
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
@ -426,17 +432,25 @@ class BaseDataset(torch.utils.data.Dataset):
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List(BucketBatchIndex) = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
# そのためバッチサイズを画像種類までに制限する
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない
# TODO 正則化画像をepochまたがりで利用する仕組み
num_of_image_types = len(set(bucket))
bucket_batch_size = min(self.batch_size, num_of_image_types)
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
#  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
#
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
# # そのためバッチサイズを画像種類までに制限する
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない
# # TO DO 正則化画像をepochまたがりで利用する仕組み
# num_of_image_types = len(set(bucket))
# bucket_batch_size = min(self.batch_size, num_of_image_types)
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
# for batch_index in range(batch_count):
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
# ↑ここまで
self.shuffle_buckets()
self._length = len(self.buckets_indices)
@ -842,6 +856,7 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0
# TODO do not record tag freq when no tag
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}

View File

@ -36,7 +36,7 @@ def utilities_tab(
)
def UI(username, password):
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
@ -50,11 +50,16 @@ def UI(username, password):
utilities_tab()
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
@ -65,7 +70,11 @@ if __name__ == '__main__':
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -495,40 +495,6 @@ def train_model(
save_inference_file(output_dir, v2, v_parameterization, output_name)
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('LoRA'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = lora_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def lora_tab(
train_data_dir_input=gr.Textbox(),
reg_data_dir_input=gr.Textbox(),
@ -644,7 +610,7 @@ def lora_tab(
caption_extension,
cache_latents,
) = gradio_training(
learning_rate_value='1e-5',
learning_rate_value='0.0001',
lr_scheduler_value='cosine',
lr_warmup_value='10',
)
@ -656,7 +622,7 @@ def lora_tab(
)
unet_lr = gr.Textbox(
label='Unet learning rate',
value='1e-3',
value='0.0001',
placeholder='Optional',
)
network_dim = gr.Slider(
@ -845,6 +811,45 @@ def lora_tab(
)
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('LoRA'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = lora_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
@ -854,7 +859,11 @@ if __name__ == '__main__':
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -0,0 +1,122 @@
from tqdm import tqdm
from library import model_util
import argparse
from transformers import CLIPTokenizer
import torch
import library.model_util as model_util
import lora
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def interrogate(args):
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
has_te_weight = False
for key in network.weights_sd.keys():
if 'lora_te' in key:
has_te_weight = True
break
if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return
del vae
print("loading tokenizer")
if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
text_encoder.to(DEVICE)
text_encoder.eval()
unet.to(DEVICE)
unet.eval() # U-Netは呼び出さないので不要だけど
# トークンをひとつひとつ当たっていく
token_id_start = 0
token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder):
embs = []
with torch.no_grad():
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
batch = []
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
# tokens = [tid] # こちらは結果がいまひとつ
batch.append(tokens)
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
# clip skip対応
batch = torch.tensor(batch).to(DEVICE)
if args.clip_skip is None:
encoder_hidden_states = text_encoder(batch)[0]
else:
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.to("cpu")
embs.extend(encoder_hidden_states)
return torch.stack(embs)
print("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
network.to(DEVICE)
network.eval()
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
# 比べる:とりあえず単純に差分の絶対値で
print("comparing...")
diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb))
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
diff = float(diff.detach().to('cpu').numpy())
diffs[token_id_start + i] = diff
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
# 結果を表示する
print("top 100:")
for i, (token, diff) in enumerate(diffs_sorted[:100]):
# if diff < 1e-6:
# break
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--batch_size", type=int, default=16,
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
args = parser.parse_args()
interrogate(args)

View File

@ -1,5 +1,5 @@
import math
import argparse
import os
import torch
@ -85,43 +85,76 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
base_alphas = {} # alpha for merged model
base_dims = {}
alpha = None
dim = None
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if 'alpha' in key:
lora_module_name = key[:key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
continue
lora_module_name = key[:key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
return merged_sd, dim, alpha
return merged_sd
def merge(args):
@ -152,7 +185,7 @@ def merge(args):
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)

179
networks/merge_lora_old.py Normal file
View File

@ -0,0 +1,179 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
text_encoder.to(merge_dtype)
unet.to(merge_dtype)
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]):
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
else:
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
return merged_sd, dim, alpha
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
args = parser.parse_args()
merge(args)

View File

@ -5,148 +5,169 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location='cpu')
metadata = None
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype):
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
print("Loading Model...")
lora_sd = load_state_dict(model, merge_dtype)
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
network_alpha = None
network_dim = None
network_alpha = None
network_dim = None
CLAMP_QUANTILE = 0.99
CLAMP_QUANTILE = 0.99
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}")
lora_down_weight = None
lora_up_weight = None
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
print("resizing lora...")
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
print("resizing lora...")
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
if (block_down_name == block_up_name) and weights_loaded:
if (block_down_name == block_up_name) and weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4)
conv2d = (len(lora_down_weight.size()) == 4)
if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if args.device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
if device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
U, S, Vh = torch.linalg.svd(full_weight_matrix)
U, S, Vh = torch.linalg.svd(full_weight_matrix)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
if args.device:
U = U.to(org_device)
Vh = Vh.to(org_device)
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
if args.device:
U = U.to(org_device)
Vh = Vh.to(org_device)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
print("resizing complete")
return o_lora_sd, network_dim, new_alpha
print("resizing complete")
return o_lora_sd
def resize(args):
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype)
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
print("resizing rank...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
# update metadata
if metadata is None:
metadata = {}
comment = metadata.get("ss_training_comment", "")
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__':

164
networks/svd_merge_lora.py Normal file
View File

@ -0,0 +1,164 @@
import math
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# merge
print(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if 'lora_down' not in key:
continue
lora_module_name = key[:key.rfind(".lora_down")]
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
# merge to weight
if device:
up_weight = up_weight.to(device)
down_weight = down_weight.to(device)
# W <- W + U * D
scale = (alpha / network_dim)
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
merged_sd[lora_module_name] = weight
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
up_weight = U
down_weight = Vh
if conv2d:
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
return merged_lora_sd
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args()
merge(args)

View File

@ -481,40 +481,6 @@ def train_model(
save_inference_file(output_dir, v2, v_parameterization, output_name)
def UI(username, password):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth TI'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = ti_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
def ti_tab(
train_data_dir=gr.Textbox(),
reg_data_dir=gr.Textbox(),
@ -823,6 +789,45 @@ def ti_tab(
)
def UI(**kwargs):
css = ''
if os.path.exists('./style.css'):
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
print('Load CSS...')
css += file.read() + '\n'
interface = gr.Blocks(css=css)
with interface:
with gr.Tab('Dreambooth TI'):
(
train_data_dir_input,
reg_data_dir_input,
output_dir_input,
logging_dir_input,
) = ti_tab()
with gr.Tab('Utilities'):
utilities_tab(
train_data_dir_input=train_data_dir_input,
reg_data_dir_input=reg_data_dir_input,
output_dir_input=output_dir_input,
logging_dir_input=logging_dir_input,
enable_copy_info_button=True,
)
# Show the interface
launch_kwargs={}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
@ -832,7 +837,11 @@ if __name__ == '__main__':
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -0,0 +1,121 @@
# This code sorts a collection of images in a given directory by their aspect ratio, groups
# them into batches of a given size, crops each image in a batch to the average aspect ratio
# of that batch, and saves the cropped images in a specified directory. The user provides
# the paths to the input directory and the output directory, as well as the desired batch
# size. The program drops any images that do not fit exactly into the batches.
import os
import cv2
import argparse
def aspect_ratio(img_path):
"""Return aspect ratio of an image"""
image = cv2.imread(img_path)
height, width = image.shape[:2]
aspect_ratio = float(width) / float(height)
return aspect_ratio
def sort_images_by_aspect_ratio(path):
"""Sort all images in a folder by aspect ratio"""
images = []
for filename in os.listdir(path):
if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png"):
img_path = os.path.join(path, filename)
images.append((img_path, aspect_ratio(img_path)))
# sort the list of tuples based on the aspect ratio
sorted_images = sorted(images, key=lambda x: x[1])
return sorted_images
def create_groups(sorted_images, n_groups):
"""Create n groups from sorted list of images"""
n = len(sorted_images)
size = n // n_groups
groups = [sorted_images[i * size : (i + 1) * size] for i in range(n_groups - 1)]
groups.append(sorted_images[(n_groups - 1) * size:])
return groups
def average_aspect_ratio(group):
"""Calculate average aspect ratio for a group"""
aspect_ratios = [aspect_ratio for _, aspect_ratio in group]
avg_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios)
return avg_aspect_ratio
def center_crop_image(image, target_aspect_ratio):
height, width = image.shape[:2]
current_aspect_ratio = float(width) / float(height)
if current_aspect_ratio == target_aspect_ratio:
return image
if current_aspect_ratio > target_aspect_ratio:
new_width = int(target_aspect_ratio * height)
x_start = (width - new_width) // 2
cropped_image = image[:, x_start:x_start+new_width]
else:
new_height = int(width / target_aspect_ratio)
y_start = (height - new_height) // 2
cropped_image = image[y_start:y_start+new_height, :]
return cropped_image
def save_cropped_images(group, folder_name, group_number, avg_aspect_ratio):
if not os.path.exists(folder_name):
os.makedirs(folder_name)
# get the smallest size of the images
small_height = 0
small_width = 0
smallest_res = 100000000
for i, image in enumerate(group):
img_path, aspect_ratio = image
image = cv2.imread(img_path)
cropped_image = center_crop_image(image, avg_aspect_ratio)
height, width = cropped_image.shape[:2]
if smallest_res > height * width:
small_height = height
small_width = width
smallest_res = height * width
# resize all images to the smallest resolution of the images in the group
for i, image in enumerate(group):
img_path, aspect_ratio = image
image = cv2.imread(img_path)
cropped_image = center_crop_image(image, avg_aspect_ratio)
resized_image = cv2.resize(cropped_image, (small_width, small_height))
save_path = os.path.join(folder_name, "group_{}_{}.jpg".format(group_number, i))
cv2.imwrite(save_path, resized_image)
def main():
parser = argparse.ArgumentParser(description='Sort images and crop them based on aspect ratio')
parser.add_argument('--path', type=str, help='Path to the directory containing images', required=True)
parser.add_argument('--dst_path', type=str, help='Path to the directory to save the cropped images', required=True)
parser.add_argument('--batch_size', type=int, help='Size of the batches to create', required=True)
args = parser.parse_args()
sorted_images = sort_images_by_aspect_ratio(args.path)
total_images = len(sorted_images)
print(f'Total images: {total_images}')
group_size = total_images // args.batch_size
print(f'Train batch size: {args.batch_size}, image group size: {group_size}')
remainder = total_images % args.batch_size
if remainder != 0:
print(f'Dropping {remainder} images that do not fit in groups...')
sorted_images = sorted_images[:-remainder]
total_images = len(sorted_images)
group_size = total_images // args.batch_size
print('Creating groups...')
groups = create_groups(sorted_images, group_size)
print('Saving cropped and resize images...')
for i, group in enumerate(groups):
avg_aspect_ratio = average_aspect_ratio(group)
save_cropped_images(group, args.dst_path, i+1, avg_aspect_ratio)
if __name__ == '__main__':
main()

View File

@ -4,6 +4,8 @@ import cv2
import argparse
import shutil
import math
from PIL import Image
import numpy as np
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
@ -35,7 +37,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
continue
# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))
# img = cv2.imread(os.path.join(src_img_folder, filename))
image = Image.open(os.path.join(src_img_folder, filename))
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
base, _ = os.path.splitext(filename)
for max_resolution in max_resolutions:
@ -72,7 +78,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
# cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
image = Image.fromarray(img)
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")

View File

@ -1,78 +0,0 @@
import os
import cv2
import argparse
import shutil
import math
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=1, caption_extension=''):
# Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')]
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Iterate through all files in src_img_folder
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp
if not filename.endswith(('.png', '.jpg', '.webp')):
# Copy the file to the destination folder if not png, jpg or webp
# shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))
for max_resolution in max_resolutions:
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image using area interpolation (best when downsampling)
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
# Calculate the new height and width that are divisible by divisible_by
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Split filename into base and extension
base, ext = os.path.splitext(filename)
new_filename = base + '+' + max_resolution + '.jpg'
# copy caption file with right name if one exist
if os.path.exists(os.path.join(src_img_folder, base + caption_extension)):
shutil.copy(os.path.join(src_img_folder, base + caption_extension), os.path.join(dst_img_folder, new_filename + caption_extension))
# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
def main():
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution(s)')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,448x448,384x384, etc, etc"', default="512x512,448x448,384x384")
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
parser.add_argument('--caption_extension', type=str, help='Extension of caption files to copy with resized images"', default=".txt")
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, args.divisible_by, args.caption_extension)
if __name__ == '__main__':
main()

View File

@ -98,12 +98,12 @@ def train(args):
# Convert the init_word to token_id
if args.init_word is not None:
init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False)
assert len(
init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください"
init_token_id = init_token_id[0]
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
else:
init_token_id = None
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
@ -120,9 +120,9 @@ def train(args):
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_id is not None:
for token_id in token_ids:
token_embeds[token_id] = token_embeds[init_token_id]
if init_token_ids is not None:
for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
@ -492,7 +492,7 @@ if __name__ == '__main__':
parser.add_argument("--token_string", type=str, default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
parser.add_argument("--init_word", type=str, default=None,
help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること")
help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument("--use_object_template", action='store_true',
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
parser.add_argument("--use_style_template", action='store_true',