commit
a008c62893
@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。(bf1
|
||||
cd sd-scripts
|
||||
git pull
|
||||
.\venv\Scripts\activate
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install --use-pep517 --upgrade -r requirements.txt
|
||||
```
|
||||
|
||||
コマンドが成功すれば新しいバージョンが使用できます。
|
||||
|
@ -163,6 +163,14 @@ This will store your a backup file with your current locally installed pip packa
|
||||
|
||||
## Change History
|
||||
|
||||
* 2023/02/11 (v20.7.2):
|
||||
- ``lora_interrogator.py`` is added in ``networks`` folder. See ``python networks\lora_interrogator.py -h`` for usage.
|
||||
- For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate.
|
||||
- Batch size can be large (like 64 or 128).
|
||||
- ``train_textual_inversion.py`` now supports multiple init words.
|
||||
- Following feature is reverted to be the same as before. Sorry for confusion:
|
||||
> Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
|
||||
- Add new tool to sort, group and average crop image in a dataset
|
||||
* 2023/02/09 (v20.7.1)
|
||||
- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
|
||||
- ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).
|
||||
|
@ -435,40 +435,6 @@ def train_model(
|
||||
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
||||
|
||||
|
||||
def UI(username, password):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Dreambooth'):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = dreambooth_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
if not username == '':
|
||||
interface.launch(auth=(username, password))
|
||||
else:
|
||||
interface.launch()
|
||||
|
||||
|
||||
def dreambooth_tab(
|
||||
train_data_dir=gr.Textbox(),
|
||||
reg_data_dir=gr.Textbox(),
|
||||
@ -735,6 +701,44 @@ def dreambooth_tab(
|
||||
)
|
||||
|
||||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Dreambooth'):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = dreambooth_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs={}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -744,7 +748,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
||||
)
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
UI(username=args.username, password=args.password)
|
||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
||||
|
@ -431,30 +431,6 @@ def remove_doublequote(file_path):
|
||||
return file_path
|
||||
|
||||
|
||||
def UI(username, password):
|
||||
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Finetune'):
|
||||
finetune_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(enable_dreambooth_tab=False)
|
||||
|
||||
# Show the interface
|
||||
if not username == '':
|
||||
interface.launch(auth=(username, password))
|
||||
else:
|
||||
interface.launch()
|
||||
|
||||
|
||||
def finetune_tab():
|
||||
dummy_ft_true = gr.Label(value=True, visible=False)
|
||||
dummy_ft_false = gr.Label(value=False, visible=False)
|
||||
@ -708,6 +684,35 @@ def finetune_tab():
|
||||
)
|
||||
|
||||
|
||||
def UI(**kwargs):
|
||||
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Finetune'):
|
||||
finetune_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(enable_dreambooth_tab=False)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs={}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -717,7 +722,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
||||
)
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
UI(username=args.username, password=args.password)
|
||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
||||
|
8
gui.bat
8
gui.bat
@ -1,10 +1,6 @@
|
||||
@echo off
|
||||
|
||||
set VENV_DIR=.\venv
|
||||
set PYTHON=python
|
||||
|
||||
call %VENV_DIR%\Scripts\activate.bat
|
||||
|
||||
%PYTHON% kohya_gui.py
|
||||
call venv\Scripts\activate.bat
|
||||
python.exe kohya_gui.py %*
|
||||
|
||||
pause
|
2
gui.ps1
2
gui.ps1
@ -1,2 +1,2 @@
|
||||
.\venv\Scripts\activate
|
||||
python.exe kohya_gui.py
|
||||
python.exe kohya_gui.py $args
|
20
kohya_gui.py
20
kohya_gui.py
@ -10,8 +10,7 @@ from library.merge_lora_gui import gradio_merge_lora_tab
|
||||
from lora_gui import lora_tab
|
||||
|
||||
|
||||
def UI(username, password, inbrowser, server_port):
|
||||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
@ -47,13 +46,18 @@ def UI(username, password, inbrowser, server_port):
|
||||
gradio_merge_lora_tab()
|
||||
|
||||
# Show the interface
|
||||
kwargs = {}
|
||||
if username:
|
||||
kwargs["auth"] = (username, password)
|
||||
launch_kwargs = {}
|
||||
username = kwargs.get('username')
|
||||
password = kwargs.get('password')
|
||||
server_port = kwargs.get('server_port', 0)
|
||||
inbrowser = kwargs.get('inbrowser', False)
|
||||
if username and password:
|
||||
launch_kwargs["auth"] = (username, password)
|
||||
if server_port > 0:
|
||||
kwargs["server_port"] = server_port
|
||||
kwargs["inbrowser"] = inbrowser
|
||||
interface.launch(**kwargs)
|
||||
launch_kwargs["server_port"] = server_port
|
||||
if inbrowser:
|
||||
launch_kwargs["inbrowser"] = inbrowser
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
|
@ -568,9 +568,11 @@ def gradio_advanced_training():
|
||||
label="Dropout caption every n epochs",
|
||||
value=0
|
||||
)
|
||||
caption_dropout_rate = gr.Number(
|
||||
caption_dropout_rate = gr.Slider(
|
||||
label="Rate of caption dropout",
|
||||
value=0
|
||||
value=0,
|
||||
minimum=0,
|
||||
maximum=1
|
||||
)
|
||||
with gr.Row():
|
||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||
|
@ -226,6 +226,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
self.dropout_rate: float = 0
|
||||
self.dropout_every_n_epochs: int = None
|
||||
self.tag_dropout_rate: float = 0
|
||||
|
||||
# augmentation
|
||||
flip_p = 0.5 if flip_aug else 0.0
|
||||
@ -284,7 +285,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
if self.shuffle_caption:
|
||||
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
||||
def dropout_tags(tokens):
|
||||
if self.tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
@ -296,13 +297,18 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if self.shuffle_keep_tokens is None:
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||
tokens = tokens[self.shuffle_keep_tokens:]
|
||||
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
|
||||
tokens = keep_tokens + tokens
|
||||
@ -426,17 +432,25 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List(BucketBatchIndex) = []
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# そのためバッチサイズを画像種類までに制限する
|
||||
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# TODO 正則化画像をepochまたがりで利用する仕組み
|
||||
num_of_image_types = len(set(bucket))
|
||||
bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
||||
|
||||
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
||||
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
||||
#
|
||||
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# # そのためバッチサイズを画像種類までに制限する
|
||||
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
||||
# num_of_image_types = len(set(bucket))
|
||||
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
# for batch_index in range(batch_count):
|
||||
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
# ↑ここまで
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
@ -842,6 +856,7 @@ class FineTuningDataset(BaseDataset):
|
||||
self.num_train_images = len(metadata) * dataset_repeats
|
||||
self.num_reg_images = 0
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
||||
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
|
||||
|
@ -36,7 +36,7 @@ def utilities_tab(
|
||||
)
|
||||
|
||||
|
||||
def UI(username, password):
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
@ -50,10 +50,15 @@ def UI(username, password):
|
||||
utilities_tab()
|
||||
|
||||
# Show the interface
|
||||
if not username == '':
|
||||
interface.launch(auth=(username, password))
|
||||
else:
|
||||
interface.launch()
|
||||
launch_kwargs={}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -65,7 +70,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
||||
)
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
UI(username=args.username, password=args.password)
|
||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
||||
|
83
lora_gui.py
83
lora_gui.py
@ -495,40 +495,6 @@ def train_model(
|
||||
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
||||
|
||||
|
||||
def UI(username, password):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('LoRA'):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = lora_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
if not username == '':
|
||||
interface.launch(auth=(username, password))
|
||||
else:
|
||||
interface.launch()
|
||||
|
||||
|
||||
def lora_tab(
|
||||
train_data_dir_input=gr.Textbox(),
|
||||
reg_data_dir_input=gr.Textbox(),
|
||||
@ -644,7 +610,7 @@ def lora_tab(
|
||||
caption_extension,
|
||||
cache_latents,
|
||||
) = gradio_training(
|
||||
learning_rate_value='1e-5',
|
||||
learning_rate_value='0.0001',
|
||||
lr_scheduler_value='cosine',
|
||||
lr_warmup_value='10',
|
||||
)
|
||||
@ -656,7 +622,7 @@ def lora_tab(
|
||||
)
|
||||
unet_lr = gr.Textbox(
|
||||
label='Unet learning rate',
|
||||
value='1e-3',
|
||||
value='0.0001',
|
||||
placeholder='Optional',
|
||||
)
|
||||
network_dim = gr.Slider(
|
||||
@ -845,6 +811,45 @@ def lora_tab(
|
||||
)
|
||||
|
||||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('LoRA'):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = lora_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs={}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -854,7 +859,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
||||
)
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
UI(username=args.username, password=args.password)
|
||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
||||
|
122
networks/lora_interrogator.py
Normal file
122
networks/lora_interrogator.py
Normal file
@ -0,0 +1,122 @@
|
||||
|
||||
|
||||
from tqdm import tqdm
|
||||
from library import model_util
|
||||
import argparse
|
||||
from transformers import CLIPTokenizer
|
||||
import torch
|
||||
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def interrogate(args):
|
||||
# いろいろ準備する
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
|
||||
print(f"loading LoRA: {args.model}")
|
||||
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
|
||||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||||
has_te_weight = False
|
||||
for key in network.weights_sd.keys():
|
||||
if 'lora_te' in key:
|
||||
has_te_weight = True
|
||||
break
|
||||
if not has_te_weight:
|
||||
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
||||
return
|
||||
del vae
|
||||
|
||||
print("loading tokenizer")
|
||||
if args.v2:
|
||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
||||
else:
|
||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
||||
|
||||
text_encoder.to(DEVICE)
|
||||
text_encoder.eval()
|
||||
unet.to(DEVICE)
|
||||
unet.eval() # U-Netは呼び出さないので不要だけど
|
||||
|
||||
# トークンをひとつひとつ当たっていく
|
||||
token_id_start = 0
|
||||
token_id_end = max(tokenizer.all_special_ids)
|
||||
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
||||
|
||||
def get_all_embeddings(text_encoder):
|
||||
embs = []
|
||||
with torch.no_grad():
|
||||
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
|
||||
batch = []
|
||||
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
|
||||
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
|
||||
# tokens = [tid] # こちらは結果がいまひとつ
|
||||
batch.append(tokens)
|
||||
|
||||
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
|
||||
# clip skip対応
|
||||
batch = torch.tensor(batch).to(DEVICE)
|
||||
if args.clip_skip is None:
|
||||
encoder_hidden_states = text_encoder(batch)[0]
|
||||
else:
|
||||
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.to("cpu")
|
||||
|
||||
embs.extend(encoder_hidden_states)
|
||||
return torch.stack(embs)
|
||||
|
||||
print("get original text encoder embeddings.")
|
||||
orig_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||||
network.to(DEVICE)
|
||||
network.eval()
|
||||
|
||||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||
print("get text encoder embeddings with lora.")
|
||||
lora_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
# 比べる:とりあえず単純に差分の絶対値で
|
||||
print("comparing...")
|
||||
diffs = {}
|
||||
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
||||
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
||||
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
|
||||
diff = float(diff.detach().to('cpu').numpy())
|
||||
diffs[token_id_start + i] = diff
|
||||
|
||||
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
|
||||
|
||||
# 結果を表示する
|
||||
print("top 100:")
|
||||
for i, (token, diff) in enumerate(diffs_sorted[:100]):
|
||||
# if diff < 1e-6:
|
||||
# break
|
||||
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
|
||||
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--model", type=str, default=None,
|
||||
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--batch_size", type=int, default=16,
|
||||
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
|
||||
parser.add_argument("--clip_skip", type=int, default=None,
|
||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||
|
||||
args = parser.parse_args()
|
||||
interrogate(args)
|
@ -1,5 +1,5 @@
|
||||
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
@ -85,43 +85,76 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
# conv2d
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, merge_dtype):
|
||||
merged_sd = {}
|
||||
base_alphas = {} # alpha for merged model
|
||||
base_dims = {}
|
||||
|
||||
alpha = None
|
||||
dim = None
|
||||
merged_sd = {}
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
# get alpha and dim
|
||||
alphas = {} # alpha for current model
|
||||
dims = {} # dims for current model
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
lora_module_name = key[:key.rfind(".alpha")]
|
||||
alpha = float(lora_sd[key].detach().numpy())
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
elif "lora_down" in key:
|
||||
lora_module_name = key[:key.rfind(".lora_down")]
|
||||
dim = lora_sd[key].size()[0]
|
||||
dims[lora_module_name] = dim
|
||||
if lora_module_name not in base_dims:
|
||||
base_dims[lora_module_name] = dim
|
||||
|
||||
for lora_module_name in dims.keys():
|
||||
if lora_module_name not in alphas:
|
||||
alpha = dims[lora_module_name]
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
|
||||
else:
|
||||
alpha = lora_sd[key].detach().numpy()
|
||||
merged_sd[key] = lora_sd[key]
|
||||
else:
|
||||
continue
|
||||
|
||||
lora_module_name = key[:key.rfind(".lora_")]
|
||||
|
||||
base_alpha = base_alphas[lora_module_name]
|
||||
alpha = alphas[lora_module_name]
|
||||
|
||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||
else:
|
||||
if "lora_down" in key:
|
||||
dim = lora_sd[key].size()[0]
|
||||
merged_sd[key] = lora_sd[key] * ratio
|
||||
merged_sd[key] = lora_sd[key] * scale
|
||||
|
||||
print(f"dim (rank): {dim}, alpha: {alpha}")
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
# set alpha to sd
|
||||
for lora_module_name, alpha in base_alphas.items():
|
||||
key = lora_module_name + ".alpha"
|
||||
merged_sd[key] = torch.tensor(alpha)
|
||||
|
||||
return merged_sd, dim, alpha
|
||||
print("merged model")
|
||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
|
||||
return merged_sd
|
||||
|
||||
|
||||
def merge(args):
|
||||
@ -152,7 +185,7 @@ def merge(args):
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
179
networks/merge_lora_old.py
Normal file
179
networks/merge_lora_old.py
Normal file
@ -0,0 +1,179 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
sd = load_file(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
text_encoder.to(merge_dtype)
|
||||
unet.to(merge_dtype)
|
||||
|
||||
# create module map
|
||||
name_to_module = {}
|
||||
for i, root_module in enumerate([text_encoder, unet]):
|
||||
if i == 0:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
||||
|
||||
# find original module for this lora
|
||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
# conv2d
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, merge_dtype):
|
||||
merged_sd = {}
|
||||
|
||||
alpha = None
|
||||
dim = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
|
||||
else:
|
||||
alpha = lora_sd[key].detach().numpy()
|
||||
merged_sd[key] = lora_sd[key]
|
||||
else:
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
||||
else:
|
||||
if "lora_down" in key:
|
||||
dim = lora_sd[key].size()[0]
|
||||
merged_sd[key] = lora_sd[key] * ratio
|
||||
|
||||
print(f"dim (rank): {dim}, alpha: {alpha}")
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
|
||||
return merged_sd, dim, alpha
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
if args.sd_model is not None:
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--models", type=str, nargs='*',
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
@ -5,37 +5,40 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
metadata = None
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
|
||||
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
print("Loading Model...")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
|
||||
@ -55,7 +58,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
scale = network_alpha/network_dim
|
||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||
|
||||
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
@ -84,7 +87,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
lora_down_weight = lora_down_weight.squeeze()
|
||||
lora_up_weight = lora_up_weight.squeeze()
|
||||
|
||||
if args.device:
|
||||
if device:
|
||||
org_device = lora_up_weight.device
|
||||
lora_up_weight = lora_up_weight.to(args.device)
|
||||
lora_down_weight = lora_down_weight.to(args.device)
|
||||
@ -125,7 +128,8 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
weights_loaded = False
|
||||
|
||||
print("resizing complete")
|
||||
return o_lora_sd
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
|
||||
def resize(args):
|
||||
|
||||
@ -143,10 +147,27 @@ def resize(args):
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype)
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
print("resizing rank...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
|
||||
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
164
networks/svd_merge_lora.py
Normal file
164
networks/svd_merge_lora.py
Normal file
@ -0,0 +1,164 @@
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
sd = load_file(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
||||
merged_sd = {}
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in tqdm(list(lora_sd.keys())):
|
||||
if 'lora_down' not in key:
|
||||
continue
|
||||
|
||||
lora_module_name = key[:key.rfind(".lora_down")]
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
network_dim = down_weight.size()[0]
|
||||
|
||||
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
|
||||
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
|
||||
|
||||
in_dim = down_weight.size()[1]
|
||||
out_dim = up_weight.size()[0]
|
||||
conv2d = len(down_weight.size()) == 4
|
||||
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
|
||||
|
||||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
else:
|
||||
weight = merged_sd[lora_module_name]
|
||||
|
||||
# merge to weight
|
||||
if device:
|
||||
up_weight = up_weight.to(device)
|
||||
down_weight = down_weight.to(device)
|
||||
|
||||
# W <- W + U * D
|
||||
scale = (alpha / network_dim)
|
||||
if not conv2d: # linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
|
||||
merged_sd[lora_module_name] = weight
|
||||
|
||||
# extract from merged weights
|
||||
print("extract new lora...")
|
||||
merged_lora_sd = {}
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
conv2d = (len(mat.size()) == 4)
|
||||
if conv2d:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U[:, :new_rank]
|
||||
S = S[:new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
Vh = Vh[:new_rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
up_weight = U
|
||||
down_weight = Vh
|
||||
|
||||
if conv2d:
|
||||
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
||||
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
|
||||
|
||||
return merged_lora_sd
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--models", type=str, nargs='*',
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--new_rank", type=int, default=4,
|
||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
@ -481,40 +481,6 @@ def train_model(
|
||||
save_inference_file(output_dir, v2, v_parameterization, output_name)
|
||||
|
||||
|
||||
def UI(username, password):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Dreambooth TI'):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = ti_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
if not username == '':
|
||||
interface.launch(auth=(username, password))
|
||||
else:
|
||||
interface.launch()
|
||||
|
||||
|
||||
def ti_tab(
|
||||
train_data_dir=gr.Textbox(),
|
||||
reg_data_dir=gr.Textbox(),
|
||||
@ -823,6 +789,45 @@ def ti_tab(
|
||||
)
|
||||
|
||||
|
||||
def UI(**kwargs):
|
||||
css = ''
|
||||
|
||||
if os.path.exists('./style.css'):
|
||||
with open(os.path.join('./style.css'), 'r', encoding='utf8') as file:
|
||||
print('Load CSS...')
|
||||
css += file.read() + '\n'
|
||||
|
||||
interface = gr.Blocks(css=css)
|
||||
|
||||
with interface:
|
||||
with gr.Tab('Dreambooth TI'):
|
||||
(
|
||||
train_data_dir_input,
|
||||
reg_data_dir_input,
|
||||
output_dir_input,
|
||||
logging_dir_input,
|
||||
) = ti_tab()
|
||||
with gr.Tab('Utilities'):
|
||||
utilities_tab(
|
||||
train_data_dir_input=train_data_dir_input,
|
||||
reg_data_dir_input=reg_data_dir_input,
|
||||
output_dir_input=output_dir_input,
|
||||
logging_dir_input=logging_dir_input,
|
||||
enable_copy_info_button=True,
|
||||
)
|
||||
|
||||
# Show the interface
|
||||
launch_kwargs={}
|
||||
if not kwargs.get('username', None) == '':
|
||||
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
|
||||
if kwargs.get('server_port', 0) > 0:
|
||||
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
|
||||
if kwargs.get('inbrowser', False):
|
||||
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
|
||||
print(launch_kwargs)
|
||||
interface.launch(**launch_kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -832,7 +837,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
'--password', type=str, default='', help='Password for authentication'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
||||
)
|
||||
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
UI(username=args.username, password=args.password)
|
||||
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
||||
|
121
tools/crop_images_to_n_buckets.py
Normal file
121
tools/crop_images_to_n_buckets.py
Normal file
@ -0,0 +1,121 @@
|
||||
# This code sorts a collection of images in a given directory by their aspect ratio, groups
|
||||
# them into batches of a given size, crops each image in a batch to the average aspect ratio
|
||||
# of that batch, and saves the cropped images in a specified directory. The user provides
|
||||
# the paths to the input directory and the output directory, as well as the desired batch
|
||||
# size. The program drops any images that do not fit exactly into the batches.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import argparse
|
||||
|
||||
def aspect_ratio(img_path):
|
||||
"""Return aspect ratio of an image"""
|
||||
image = cv2.imread(img_path)
|
||||
height, width = image.shape[:2]
|
||||
aspect_ratio = float(width) / float(height)
|
||||
return aspect_ratio
|
||||
|
||||
def sort_images_by_aspect_ratio(path):
|
||||
"""Sort all images in a folder by aspect ratio"""
|
||||
images = []
|
||||
for filename in os.listdir(path):
|
||||
if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png"):
|
||||
img_path = os.path.join(path, filename)
|
||||
images.append((img_path, aspect_ratio(img_path)))
|
||||
# sort the list of tuples based on the aspect ratio
|
||||
sorted_images = sorted(images, key=lambda x: x[1])
|
||||
return sorted_images
|
||||
|
||||
def create_groups(sorted_images, n_groups):
|
||||
"""Create n groups from sorted list of images"""
|
||||
n = len(sorted_images)
|
||||
size = n // n_groups
|
||||
groups = [sorted_images[i * size : (i + 1) * size] for i in range(n_groups - 1)]
|
||||
groups.append(sorted_images[(n_groups - 1) * size:])
|
||||
return groups
|
||||
|
||||
def average_aspect_ratio(group):
|
||||
"""Calculate average aspect ratio for a group"""
|
||||
aspect_ratios = [aspect_ratio for _, aspect_ratio in group]
|
||||
avg_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios)
|
||||
return avg_aspect_ratio
|
||||
|
||||
def center_crop_image(image, target_aspect_ratio):
|
||||
height, width = image.shape[:2]
|
||||
current_aspect_ratio = float(width) / float(height)
|
||||
|
||||
if current_aspect_ratio == target_aspect_ratio:
|
||||
return image
|
||||
|
||||
if current_aspect_ratio > target_aspect_ratio:
|
||||
new_width = int(target_aspect_ratio * height)
|
||||
x_start = (width - new_width) // 2
|
||||
cropped_image = image[:, x_start:x_start+new_width]
|
||||
else:
|
||||
new_height = int(width / target_aspect_ratio)
|
||||
y_start = (height - new_height) // 2
|
||||
cropped_image = image[y_start:y_start+new_height, :]
|
||||
|
||||
return cropped_image
|
||||
|
||||
def save_cropped_images(group, folder_name, group_number, avg_aspect_ratio):
|
||||
if not os.path.exists(folder_name):
|
||||
os.makedirs(folder_name)
|
||||
|
||||
# get the smallest size of the images
|
||||
small_height = 0
|
||||
small_width = 0
|
||||
smallest_res = 100000000
|
||||
for i, image in enumerate(group):
|
||||
img_path, aspect_ratio = image
|
||||
image = cv2.imread(img_path)
|
||||
cropped_image = center_crop_image(image, avg_aspect_ratio)
|
||||
height, width = cropped_image.shape[:2]
|
||||
if smallest_res > height * width:
|
||||
small_height = height
|
||||
small_width = width
|
||||
smallest_res = height * width
|
||||
|
||||
# resize all images to the smallest resolution of the images in the group
|
||||
for i, image in enumerate(group):
|
||||
img_path, aspect_ratio = image
|
||||
image = cv2.imread(img_path)
|
||||
cropped_image = center_crop_image(image, avg_aspect_ratio)
|
||||
resized_image = cv2.resize(cropped_image, (small_width, small_height))
|
||||
save_path = os.path.join(folder_name, "group_{}_{}.jpg".format(group_number, i))
|
||||
cv2.imwrite(save_path, resized_image)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Sort images and crop them based on aspect ratio')
|
||||
parser.add_argument('--path', type=str, help='Path to the directory containing images', required=True)
|
||||
parser.add_argument('--dst_path', type=str, help='Path to the directory to save the cropped images', required=True)
|
||||
parser.add_argument('--batch_size', type=int, help='Size of the batches to create', required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
sorted_images = sort_images_by_aspect_ratio(args.path)
|
||||
total_images = len(sorted_images)
|
||||
print(f'Total images: {total_images}')
|
||||
|
||||
group_size = total_images // args.batch_size
|
||||
|
||||
print(f'Train batch size: {args.batch_size}, image group size: {group_size}')
|
||||
remainder = total_images % args.batch_size
|
||||
|
||||
if remainder != 0:
|
||||
print(f'Dropping {remainder} images that do not fit in groups...')
|
||||
sorted_images = sorted_images[:-remainder]
|
||||
total_images = len(sorted_images)
|
||||
group_size = total_images // args.batch_size
|
||||
|
||||
print('Creating groups...')
|
||||
groups = create_groups(sorted_images, group_size)
|
||||
|
||||
print('Saving cropped and resize images...')
|
||||
for i, group in enumerate(groups):
|
||||
avg_aspect_ratio = average_aspect_ratio(group)
|
||||
save_cropped_images(group, args.dst_path, i+1, avg_aspect_ratio)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -4,6 +4,8 @@ import cv2
|
||||
import argparse
|
||||
import shutil
|
||||
import math
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
||||
@ -35,7 +37,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
continue
|
||||
|
||||
# Load image
|
||||
img = cv2.imread(os.path.join(src_img_folder, filename))
|
||||
# img = cv2.imread(os.path.join(src_img_folder, filename))
|
||||
image = Image.open(os.path.join(src_img_folder, filename))
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
|
||||
base, _ = os.path.splitext(filename)
|
||||
for max_resolution in max_resolutions:
|
||||
@ -72,7 +78,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
|
||||
|
||||
# Save resized image in dst_img_folder
|
||||
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
||||
# cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
||||
image = Image.fromarray(img)
|
||||
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
|
||||
|
||||
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
||||
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||
|
||||
|
@ -1,78 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import argparse
|
||||
import shutil
|
||||
import math
|
||||
|
||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=1, caption_extension=''):
|
||||
# Split the max_resolution string by "," and strip any whitespaces
|
||||
max_resolutions = [res.strip() for res in max_resolution.split(',')]
|
||||
|
||||
# Create destination folder if it does not exist
|
||||
if not os.path.exists(dst_img_folder):
|
||||
os.makedirs(dst_img_folder)
|
||||
|
||||
# Iterate through all files in src_img_folder
|
||||
for filename in os.listdir(src_img_folder):
|
||||
# Check if the image is png, jpg or webp
|
||||
if not filename.endswith(('.png', '.jpg', '.webp')):
|
||||
# Copy the file to the destination folder if not png, jpg or webp
|
||||
# shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
||||
continue
|
||||
|
||||
# Load image
|
||||
img = cv2.imread(os.path.join(src_img_folder, filename))
|
||||
|
||||
for max_resolution in max_resolutions:
|
||||
# Calculate max_pixels from max_resolution string
|
||||
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
||||
|
||||
# Calculate current number of pixels
|
||||
current_pixels = img.shape[0] * img.shape[1]
|
||||
|
||||
# Check if the image needs resizing
|
||||
if current_pixels > max_pixels:
|
||||
# Calculate scaling factor
|
||||
scale_factor = max_pixels / current_pixels
|
||||
|
||||
# Calculate new dimensions
|
||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||
|
||||
# Resize image using area interpolation (best when downsampling)
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||
|
||||
# Calculate the new height and width that are divisible by divisible_by
|
||||
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
||||
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
|
||||
|
||||
# Center crop the image to the calculated dimensions
|
||||
y = int((img.shape[0] - new_height) / 2)
|
||||
x = int((img.shape[1] - new_width) / 2)
|
||||
img = img[y:y + new_height, x:x + new_width]
|
||||
|
||||
# Split filename into base and extension
|
||||
base, ext = os.path.splitext(filename)
|
||||
new_filename = base + '+' + max_resolution + '.jpg'
|
||||
|
||||
# copy caption file with right name if one exist
|
||||
if os.path.exists(os.path.join(src_img_folder, base + caption_extension)):
|
||||
shutil.copy(os.path.join(src_img_folder, base + caption_extension), os.path.join(dst_img_folder, new_filename + caption_extension))
|
||||
|
||||
# Save resized image in dst_img_folder
|
||||
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
||||
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Resize images in a folder to a specified max resolution(s)')
|
||||
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images')
|
||||
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
|
||||
parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,448x448,384x384, etc, etc"', default="512x512,448x448,384x384")
|
||||
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
|
||||
parser.add_argument('--caption_extension', type=str, help='Extension of caption files to copy with resized images"', default=".txt")
|
||||
args = parser.parse_args()
|
||||
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, args.divisible_by, args.caption_extension)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -98,12 +98,12 @@ def train(args):
|
||||
|
||||
# Convert the init_word to token_id
|
||||
if args.init_word is not None:
|
||||
init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
assert len(
|
||||
init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください"
|
||||
init_token_id = init_token_id[0]
|
||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||
print(
|
||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
|
||||
else:
|
||||
init_token_id = None
|
||||
init_token_ids = None
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
@ -120,9 +120,9 @@ def train(args):
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
if init_token_id is not None:
|
||||
for token_id in token_ids:
|
||||
token_embeds[token_id] = token_embeds[init_token_id]
|
||||
if init_token_ids is not None:
|
||||
for i, token_id in enumerate(token_ids):
|
||||
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
|
||||
# load weights
|
||||
@ -492,7 +492,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--token_string", type=str, default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
|
||||
parser.add_argument("--init_word", type=str, default=None,
|
||||
help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること")
|
||||
help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument("--use_object_template", action='store_true',
|
||||
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
|
||||
parser.add_argument("--use_style_template", action='store_true',
|
||||
|
Loading…
Reference in New Issue
Block a user