commit
9a7bb4c624
@ -143,6 +143,14 @@ Then redo the installation instruction within the kohya_ss venv.
|
|||||||
|
|
||||||
## Change history
|
## Change history
|
||||||
|
|
||||||
|
* 2023/02/09 (v20.7.1)
|
||||||
|
- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
|
||||||
|
- ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).
|
||||||
|
- ``--caption_dropout_every_n_epochs`` option specifies how many epochs to drop captions. If ``3`` is specified, in epoch 3, 6, 9 ..., images are trained with all captions empty. Default is None (no dropout).
|
||||||
|
- ``--caption_tag_dropout_rate`` option specified the dropout rate for tags (comma separated tokens) (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the tag is removed from the caption. If ``--keep_tokens`` option is set, these tokens (tags) are not dropped. Default is 0 (no droupout).
|
||||||
|
- The bulk image downsampling script is added. Documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) (in Jpanaese). Thanks to bmaltais!
|
||||||
|
- Typo check is added. Thanks to shirayu!
|
||||||
|
- Add option to autolaunch the GUI in a browser and set the server_port. USe either `gui.ps1 --inbrowser --server_port 3456`or `gui.cmd -inbrowser -server_port 3456`
|
||||||
* 2023/02/06 (v20.7.0)
|
* 2023/02/06 (v20.7.0)
|
||||||
- ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
|
- ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
|
||||||
- ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before.
|
- ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before.
|
||||||
|
@ -88,6 +88,7 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -177,6 +178,7 @@ def open_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -250,6 +252,7 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -416,6 +419,8 @@ def train_model(
|
|||||||
bucket_no_upscale=bucket_no_upscale,
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
bucket_reso_steps=bucket_reso_steps,
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate=caption_dropout_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -627,6 +632,7 @@ def dreambooth_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -695,6 +701,7 @@ def dreambooth_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -36,6 +36,10 @@ def train(args):
|
|||||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||||
args.dataset_repeats, args.debug_dataset)
|
args.dataset_repeats, args.debug_dataset)
|
||||||
|
|
||||||
|
# 学習データのdropout率を設定する
|
||||||
|
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
||||||
|
|
||||||
train_dataset.make_buckets()
|
train_dataset.make_buckets()
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
@ -226,6 +230,8 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
train_dataset.set_current_epoch(epoch + 1)
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
@ -332,7 +338,7 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, False, True)
|
train_util.add_dataset_arguments(parser, False, True, True)
|
||||||
train_util.add_training_arguments(parser, False)
|
train_util.add_training_arguments(parser, False)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
|
|
||||||
|
@ -84,6 +84,7 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -179,6 +180,7 @@ def open_config_file(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -259,6 +261,7 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# create caption json file
|
# create caption json file
|
||||||
if generate_caption_database:
|
if generate_caption_database:
|
||||||
@ -405,6 +408,8 @@ def train_model(
|
|||||||
bucket_no_upscale=bucket_no_upscale,
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
bucket_reso_steps=bucket_reso_steps,
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate=caption_dropout_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -614,6 +619,7 @@ def finetune_tab():
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -678,6 +684,7 @@ def finetune_tab():
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_run.click(train_model, inputs=settings_list)
|
button_run.click(train_model, inputs=settings_list)
|
||||||
|
23
gui.bat
23
gui.bat
@ -1,10 +1,23 @@
|
|||||||
@echo off
|
@echo off
|
||||||
|
|
||||||
set VENV_DIR=.\venv
|
REM Use this batch file with the following options:
|
||||||
set PYTHON=python
|
REM -inbrowser - To launch the program in the browser
|
||||||
|
REM -server_port [port number] - To specify the server port
|
||||||
|
|
||||||
call %VENV_DIR%\Scripts\activate.bat
|
set inbrowserOption=
|
||||||
|
set serverPortOption=
|
||||||
|
|
||||||
%PYTHON% kohya_gui.py
|
if "%1" == "-server_port" (
|
||||||
|
set serverPortOption=--server_port %2
|
||||||
|
if "%3" == "-inbrowser" (
|
||||||
|
set inbrowserOption=--inbrowser
|
||||||
|
)
|
||||||
|
) else if "%1" == "-inbrowser" (
|
||||||
|
set inbrowserOption=--inbrowser
|
||||||
|
if "%2" == "-server_port" (
|
||||||
|
set serverPortOption=--server_port %3
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
pause
|
call .\venv\Scripts\activate.bat
|
||||||
|
python.exe kohya_gui.py %inbrowserOption% %serverPortOption%
|
||||||
|
11
gui.ps1
11
gui.ps1
@ -1,2 +1,11 @@
|
|||||||
|
# Example command: .\gui.ps1 -server_port 8000 -inbrowser
|
||||||
|
|
||||||
|
param([string]$username="", [string]$password="", [switch]$inbrowser, [int]$server_port)
|
||||||
.\venv\Scripts\activate
|
.\venv\Scripts\activate
|
||||||
python.exe kohya_gui.py
|
|
||||||
|
if ($server_port -le 0 -and $inbrowser -eq $false) {
|
||||||
|
Write-Host "Error: You must provide either the --server_port or --inbrowser argument."
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
python.exe kohya_gui.py --username $username --password $password --server_port $server_port --inbrowser
|
20
kohya_gui.py
20
kohya_gui.py
@ -10,7 +10,7 @@ from library.merge_lora_gui import gradio_merge_lora_tab
|
|||||||
from lora_gui import lora_tab
|
from lora_gui import lora_tab
|
||||||
|
|
||||||
|
|
||||||
def UI(username, password):
|
def UI(username, password, inbrowser, server_port):
|
||||||
|
|
||||||
css = ''
|
css = ''
|
||||||
|
|
||||||
@ -47,11 +47,13 @@ def UI(username, password):
|
|||||||
gradio_merge_lora_tab()
|
gradio_merge_lora_tab()
|
||||||
|
|
||||||
# Show the interface
|
# Show the interface
|
||||||
if not username == '':
|
kwargs = {}
|
||||||
interface.launch(auth=(username, password))
|
if username:
|
||||||
else:
|
kwargs["auth"] = (username, password)
|
||||||
interface.launch()
|
if server_port > 0:
|
||||||
|
kwargs["server_port"] = server_port
|
||||||
|
kwargs["inbrowser"] = inbrowser
|
||||||
|
interface.launch(**kwargs)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
@ -62,7 +64,11 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--password', type=str, default='', help='Password for authentication'
|
'--password', type=str, default='', help='Password for authentication'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--server_port', type=int, default=0, help='Port to run the server listener on'
|
||||||
|
)
|
||||||
|
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
UI(username=args.username, password=args.password)
|
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
|
||||||
|
@ -563,6 +563,15 @@ def gradio_advanced_training():
|
|||||||
random_crop = gr.Checkbox(
|
random_crop = gr.Checkbox(
|
||||||
label='Random crop instead of center crop', value=False
|
label='Random crop instead of center crop', value=False
|
||||||
)
|
)
|
||||||
|
with gr.Row():
|
||||||
|
caption_dropout_every_n_epochs = gr.Number(
|
||||||
|
label="Dropout caption every n epochs",
|
||||||
|
value=0
|
||||||
|
)
|
||||||
|
caption_dropout_rate = gr.Number(
|
||||||
|
label="Rate of caption dropout",
|
||||||
|
value=0
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||||
resume = gr.Textbox(
|
resume = gr.Textbox(
|
||||||
@ -599,6 +608,7 @@ def gradio_advanced_training():
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -622,6 +632,12 @@ def run_cmd_advanced_training(**kwargs):
|
|||||||
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
|
f' --keep_tokens="{kwargs.get("keep_tokens", "")}"'
|
||||||
if int(kwargs.get('keep_tokens', 0)) > 0
|
if int(kwargs.get('keep_tokens', 0)) > 0
|
||||||
else '',
|
else '',
|
||||||
|
f' --caption_dropout_every_n_epochs="{kwargs.get("caption_dropout_every_n_epochs", "")}"'
|
||||||
|
if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
|
||||||
|
else '',
|
||||||
|
f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"'
|
||||||
|
if float(kwargs.get('caption_dropout_rate', 0)) > 0
|
||||||
|
else '',
|
||||||
|
|
||||||
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
|
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
|
||||||
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
|
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
|
||||||
|
@ -113,7 +113,7 @@ class BucketManager():
|
|||||||
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
||||||
self.predefined_resos = resos.copy()
|
self.predefined_resos = resos.copy()
|
||||||
self.predefined_resos_set = set(resos)
|
self.predefined_resos_set = set(resos)
|
||||||
self.predifined_aspect_ratios = np.array([w / h for w, h in resos])
|
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||||
|
|
||||||
def add_if_new_reso(self, reso):
|
def add_if_new_reso(self, reso):
|
||||||
if reso not in self.reso_to_id:
|
if reso not in self.reso_to_id:
|
||||||
@ -135,7 +135,7 @@ class BucketManager():
|
|||||||
if reso in self.predefined_resos_set:
|
if reso in self.predefined_resos_set:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
ar_errors = self.predifined_aspect_ratios - aspect_ratio
|
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
||||||
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
||||||
reso = self.predefined_resos[predefined_bucket_id]
|
reso = self.predefined_resos[predefined_bucket_id]
|
||||||
|
|
||||||
@ -223,6 +223,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
|
|
||||||
|
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||||
|
self.dropout_rate: float = 0
|
||||||
|
self.dropout_every_n_epochs: int = None
|
||||||
|
|
||||||
# augmentation
|
# augmentation
|
||||||
flip_p = 0.5 if flip_aug else 0.0
|
flip_p = 0.5 if flip_aug else 0.0
|
||||||
if color_aug:
|
if color_aug:
|
||||||
@ -247,6 +251,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.replacements = {}
|
self.replacements = {}
|
||||||
|
|
||||||
|
def set_current_epoch(self, epoch):
|
||||||
|
self.current_epoch = epoch
|
||||||
|
|
||||||
|
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
||||||
|
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.dropout_every_n_epochs = dropout_every_n_epochs
|
||||||
|
self.tag_dropout_rate = tag_dropout_rate
|
||||||
|
|
||||||
def set_tag_frequency(self, dir_name, captions):
|
def set_tag_frequency(self, dir_name, captions):
|
||||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||||
self.tag_frequency[dir_name] = frequency_for_dir
|
self.tag_frequency[dir_name] = frequency_for_dir
|
||||||
@ -264,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.replacements[str_from] = str_to
|
self.replacements[str_from] = str_to
|
||||||
|
|
||||||
def process_caption(self, caption):
|
def process_caption(self, caption):
|
||||||
if self.shuffle_caption:
|
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
||||||
tokens = caption.strip().split(",")
|
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
||||||
if self.shuffle_keep_tokens is None:
|
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
||||||
random.shuffle(tokens)
|
|
||||||
else:
|
|
||||||
if len(tokens) > self.shuffle_keep_tokens:
|
|
||||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
|
||||||
tokens = tokens[self.shuffle_keep_tokens:]
|
|
||||||
random.shuffle(tokens)
|
|
||||||
tokens = keep_tokens + tokens
|
|
||||||
caption = ",".join(tokens).strip()
|
|
||||||
|
|
||||||
for str_from, str_to in self.replacements.items():
|
if is_drop_out:
|
||||||
if str_from == "":
|
caption = ""
|
||||||
# replace all
|
else:
|
||||||
if type(str_to) == list:
|
if self.shuffle_caption:
|
||||||
caption = random.choice(str_to)
|
def dropout_tags(tokens):
|
||||||
|
if self.tag_dropout_rate <= 0:
|
||||||
|
return tokens
|
||||||
|
l = []
|
||||||
|
for token in tokens:
|
||||||
|
if random.random() >= self.tag_dropout_rate:
|
||||||
|
l.append(token)
|
||||||
|
return l
|
||||||
|
|
||||||
|
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||||
|
if self.shuffle_keep_tokens is None:
|
||||||
|
random.shuffle(tokens)
|
||||||
|
tokens = dropout_tags(tokens)
|
||||||
else:
|
else:
|
||||||
caption = str_to
|
if len(tokens) > self.shuffle_keep_tokens:
|
||||||
else:
|
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||||
caption = caption.replace(str_from, str_to)
|
tokens = tokens[self.shuffle_keep_tokens:]
|
||||||
|
random.shuffle(tokens)
|
||||||
|
tokens = dropout_tags(tokens)
|
||||||
|
|
||||||
|
tokens = keep_tokens + tokens
|
||||||
|
caption = ", ".join(tokens)
|
||||||
|
|
||||||
|
# textual inversion対応
|
||||||
|
for str_from, str_to in self.replacements.items():
|
||||||
|
if str_from == "":
|
||||||
|
# replace all
|
||||||
|
if type(str_to) == list:
|
||||||
|
caption = random.choice(str_to)
|
||||||
|
else:
|
||||||
|
caption = str_to
|
||||||
|
else:
|
||||||
|
caption = caption.replace(str_from, str_to)
|
||||||
|
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
@ -907,6 +940,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
def debug_dataset(train_dataset, show_input_ids=False):
|
def debug_dataset(train_dataset, show_input_ids=False):
|
||||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||||
print("Escape for exit. / Escキーで中断、終了します")
|
print("Escape for exit. / Escキーで中断、終了します")
|
||||||
|
|
||||||
|
train_dataset.set_current_epoch(1)
|
||||||
k = 0
|
k = 0
|
||||||
for i, example in enumerate(train_dataset):
|
for i, example in enumerate(train_dataset):
|
||||||
if example['latents'] is not None:
|
if example['latents'] is not None:
|
||||||
@ -1377,7 +1412,7 @@ def verify_training_args(args: argparse.Namespace):
|
|||||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||||
|
|
||||||
|
|
||||||
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
|
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
||||||
# dataset common
|
# dataset common
|
||||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("--shuffle_caption", action="store_true",
|
parser.add_argument("--shuffle_caption", action="store_true",
|
||||||
@ -1408,6 +1443,16 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
|||||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||||
|
|
||||||
|
if support_caption_dropout:
|
||||||
|
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||||
|
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||||
|
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
||||||
|
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
||||||
|
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
||||||
|
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||||
|
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
||||||
|
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
||||||
|
|
||||||
if support_dreambooth:
|
if support_dreambooth:
|
||||||
# DreamBooth dataset
|
# DreamBooth dataset
|
||||||
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
||||||
@ -1718,4 +1763,4 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
|
|||||||
return (tensor_pil, img_path)
|
return (tensor_pil, img_path)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
13
lora_gui.py
13
lora_gui.py
@ -99,6 +99,7 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -195,6 +196,7 @@ def open_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -275,7 +277,8 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
):
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
return
|
return
|
||||||
@ -380,7 +383,7 @@ def train_model(
|
|||||||
|
|
||||||
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
|
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
|
||||||
|
|
||||||
run_cmd += f' --bucket_reso_steps=1 --bucket_no_upscale' # --random_crop'
|
# run_cmd += f' --caption_dropout_rate="0.1" --caption_dropout_every_n_epochs=1' # --random_crop'
|
||||||
|
|
||||||
if v2:
|
if v2:
|
||||||
run_cmd += ' --v2'
|
run_cmd += ' --v2'
|
||||||
@ -440,7 +443,7 @@ def train_model(
|
|||||||
else:
|
else:
|
||||||
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
|
run_cmd += f' --lr_scheduler_num_cycles="{epoch}"'
|
||||||
if not lr_scheduler_power == '':
|
if not lr_scheduler_power == '':
|
||||||
run_cmd += f' --output_name="{lr_scheduler_power}"'
|
run_cmd += f' --lr_scheduler_power="{lr_scheduler_power}"'
|
||||||
|
|
||||||
run_cmd += run_cmd_training(
|
run_cmd += run_cmd_training(
|
||||||
learning_rate=learning_rate,
|
learning_rate=learning_rate,
|
||||||
@ -476,6 +479,8 @@ def train_model(
|
|||||||
bucket_no_upscale=bucket_no_upscale,
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
bucket_reso_steps=bucket_reso_steps,
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate=caption_dropout_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(run_cmd)
|
print(run_cmd)
|
||||||
@ -725,6 +730,7 @@ def lora_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -805,6 +811,7 @@ def lora_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from library import train_util
|
from library import train_util
|
||||||
@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
||||||
loras = []
|
loras = []
|
||||||
for name, module in root_module.named_modules():
|
for name, module in root_module.named_modules():
|
||||||
if module.__class__.__name__ in target_replace_modules:
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
|
@ -1,24 +1,26 @@
|
|||||||
accelerate==0.15.0
|
accelerate==0.15.0
|
||||||
transformers==4.26.0
|
transformers==4.26.0
|
||||||
ftfy
|
ftfy==6.1.1
|
||||||
albumentations
|
albumentations==1.3.0
|
||||||
opencv-python
|
opencv-python==4.7.0.68
|
||||||
einops
|
einops==0.6.0
|
||||||
diffusers[torch]==0.10.2
|
diffusers[torch]==0.10.2
|
||||||
pytorch_lightning
|
pytorch-lightning==1.9.0
|
||||||
bitsandbytes==0.35.0
|
bitsandbytes==0.35.0
|
||||||
tensorboard
|
tensorboard==2.10.1
|
||||||
safetensors==0.2.6
|
safetensors==0.2.6
|
||||||
gradio==3.16.2
|
gradio==3.16.2
|
||||||
altair
|
altair==4.2.2
|
||||||
easygui
|
easygui==0.98.3
|
||||||
tk
|
tk==0.1.0
|
||||||
# for BLIP captioning
|
# for BLIP captioning
|
||||||
requests
|
requests==2.28.2
|
||||||
timm
|
timm==0.6.12
|
||||||
fairscale
|
fairscale==0.4.13
|
||||||
# for WD14 captioning
|
# for WD14 captioning
|
||||||
tensorflow<2.11
|
# tensorflow<2.11
|
||||||
huggingface-hub
|
tensorflow==2.10.1
|
||||||
|
huggingface-hub==0.12.0
|
||||||
|
xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
.
|
.
|
@ -94,6 +94,7 @@ def save_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -193,6 +194,7 @@ def open_configuration(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
# Get list of function parameters and values
|
# Get list of function parameters and values
|
||||||
parameters = list(locals().items())
|
parameters = list(locals().items())
|
||||||
@ -272,6 +274,7 @@ def train_model(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
):
|
):
|
||||||
if pretrained_model_name_or_path == '':
|
if pretrained_model_name_or_path == '':
|
||||||
msgbox('Source model information is missing')
|
msgbox('Source model information is missing')
|
||||||
@ -453,6 +456,8 @@ def train_model(
|
|||||||
bucket_no_upscale=bucket_no_upscale,
|
bucket_no_upscale=bucket_no_upscale,
|
||||||
random_crop=random_crop,
|
random_crop=random_crop,
|
||||||
bucket_reso_steps=bucket_reso_steps,
|
bucket_reso_steps=bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
|
||||||
|
caption_dropout_rate=caption_dropout_rate,
|
||||||
)
|
)
|
||||||
run_cmd += f' --token_string="{token_string}"'
|
run_cmd += f' --token_string="{token_string}"'
|
||||||
run_cmd += f' --init_word="{init_word}"'
|
run_cmd += f' --init_word="{init_word}"'
|
||||||
@ -709,6 +714,7 @@ def ti_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
) = gradio_advanced_training()
|
) = gradio_advanced_training()
|
||||||
color_aug.change(
|
color_aug.change(
|
||||||
color_aug_changed,
|
color_aug_changed,
|
||||||
@ -783,6 +789,7 @@ def ti_tab(
|
|||||||
bucket_no_upscale,
|
bucket_no_upscale,
|
||||||
random_crop,
|
random_crop,
|
||||||
bucket_reso_steps,
|
bucket_reso_steps,
|
||||||
|
caption_dropout_every_n_epochs, caption_dropout_rate,
|
||||||
]
|
]
|
||||||
|
|
||||||
button_open_config.click(
|
button_open_config.click(
|
||||||
|
113
tools/resize_images_to_resolution.py
Normal file
113
tools/resize_images_to_resolution.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
||||||
|
# Split the max_resolution string by "," and strip any whitespaces
|
||||||
|
max_resolutions = [res.strip() for res in max_resolution.split(',')]
|
||||||
|
|
||||||
|
# # Calculate max_pixels from max_resolution string
|
||||||
|
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
||||||
|
|
||||||
|
# Create destination folder if it does not exist
|
||||||
|
if not os.path.exists(dst_img_folder):
|
||||||
|
os.makedirs(dst_img_folder)
|
||||||
|
|
||||||
|
# Select interpolation method
|
||||||
|
if interpolation == 'lanczos4':
|
||||||
|
cv2_interpolation = cv2.INTER_LANCZOS4
|
||||||
|
elif interpolation == 'cubic':
|
||||||
|
cv2_interpolation = cv2.INTER_CUBIC
|
||||||
|
else:
|
||||||
|
cv2_interpolation = cv2.INTER_AREA
|
||||||
|
|
||||||
|
# Iterate through all files in src_img_folder
|
||||||
|
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
|
||||||
|
for filename in os.listdir(src_img_folder):
|
||||||
|
# Check if the image is png, jpg or webp etc...
|
||||||
|
if not filename.endswith(img_exts):
|
||||||
|
# Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
|
||||||
|
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Load image
|
||||||
|
img = cv2.imread(os.path.join(src_img_folder, filename))
|
||||||
|
|
||||||
|
base, _ = os.path.splitext(filename)
|
||||||
|
for max_resolution in max_resolutions:
|
||||||
|
# Calculate max_pixels from max_resolution string
|
||||||
|
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
||||||
|
|
||||||
|
# Calculate current number of pixels
|
||||||
|
current_pixels = img.shape[0] * img.shape[1]
|
||||||
|
|
||||||
|
# Check if the image needs resizing
|
||||||
|
if current_pixels > max_pixels:
|
||||||
|
# Calculate scaling factor
|
||||||
|
scale_factor = max_pixels / current_pixels
|
||||||
|
|
||||||
|
# Calculate new dimensions
|
||||||
|
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||||
|
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||||
|
|
||||||
|
# Resize image
|
||||||
|
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
||||||
|
else:
|
||||||
|
new_height, new_width = img.shape[0:2]
|
||||||
|
|
||||||
|
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
|
||||||
|
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
||||||
|
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
|
||||||
|
|
||||||
|
# Center crop the image to the calculated dimensions
|
||||||
|
y = int((img.shape[0] - new_height) / 2)
|
||||||
|
x = int((img.shape[1] - new_width) / 2)
|
||||||
|
img = img[y:y + new_height, x:x + new_width]
|
||||||
|
|
||||||
|
# Split filename into base and extension
|
||||||
|
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
|
||||||
|
|
||||||
|
# Save resized image in dst_img_folder
|
||||||
|
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
||||||
|
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
||||||
|
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||||
|
|
||||||
|
# If other files with same basename, copy them with resolution suffix
|
||||||
|
if copy_associated_files:
|
||||||
|
asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
|
||||||
|
for asoc_file in asoc_files:
|
||||||
|
ext = os.path.splitext(asoc_file)[1]
|
||||||
|
if ext in img_exts:
|
||||||
|
continue
|
||||||
|
for max_resolution in max_resolutions:
|
||||||
|
new_asoc_file = base + '+' + max_resolution + ext
|
||||||
|
print(f"Copy {asoc_file} as {new_asoc_file}")
|
||||||
|
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
|
||||||
|
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
|
||||||
|
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
|
||||||
|
parser.add_argument('--max_resolution', type=str,
|
||||||
|
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
||||||
|
parser.add_argument('--divisible_by', type=int,
|
||||||
|
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
||||||
|
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
|
||||||
|
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
|
||||||
|
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
||||||
|
parser.add_argument('--copy_associated_files', action='store_true',
|
||||||
|
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
|
||||||
|
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -4,13 +4,10 @@ import argparse
|
|||||||
import shutil
|
import shutil
|
||||||
import math
|
import math
|
||||||
|
|
||||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2):
|
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=1, caption_extension=''):
|
||||||
# Split the max_resolution string by "," and strip any whitespaces
|
# Split the max_resolution string by "," and strip any whitespaces
|
||||||
max_resolutions = [res.strip() for res in max_resolution.split(',')]
|
max_resolutions = [res.strip() for res in max_resolution.split(',')]
|
||||||
|
|
||||||
# # Calculate max_pixels from max_resolution string
|
|
||||||
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
|
||||||
|
|
||||||
# Create destination folder if it does not exist
|
# Create destination folder if it does not exist
|
||||||
if not os.path.exists(dst_img_folder):
|
if not os.path.exists(dst_img_folder):
|
||||||
os.makedirs(dst_img_folder)
|
os.makedirs(dst_img_folder)
|
||||||
@ -20,7 +17,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||||||
# Check if the image is png, jpg or webp
|
# Check if the image is png, jpg or webp
|
||||||
if not filename.endswith(('.png', '.jpg', '.webp')):
|
if not filename.endswith(('.png', '.jpg', '.webp')):
|
||||||
# Copy the file to the destination folder if not png, jpg or webp
|
# Copy the file to the destination folder if not png, jpg or webp
|
||||||
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
# shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Load image
|
# Load image
|
||||||
@ -42,8 +39,8 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||||
|
|
||||||
# Resize image
|
# Resize image using area interpolation (best when downsampling)
|
||||||
img = cv2.resize(img, (new_width, new_height))
|
img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
# Calculate the new height and width that are divisible by divisible_by
|
# Calculate the new height and width that are divisible by divisible_by
|
||||||
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
||||||
@ -57,7 +54,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||||||
# Split filename into base and extension
|
# Split filename into base and extension
|
||||||
base, ext = os.path.splitext(filename)
|
base, ext = os.path.splitext(filename)
|
||||||
new_filename = base + '+' + max_resolution + '.jpg'
|
new_filename = base + '+' + max_resolution + '.jpg'
|
||||||
|
|
||||||
|
# copy caption file with right name if one exist
|
||||||
|
if os.path.exists(os.path.join(src_img_folder, base + caption_extension)):
|
||||||
|
shutil.copy(os.path.join(src_img_folder, base + caption_extension), os.path.join(dst_img_folder, new_filename + caption_extension))
|
||||||
|
|
||||||
# Save resized image in dst_img_folder
|
# Save resized image in dst_img_folder
|
||||||
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
||||||
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
print(f"Resized image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||||
@ -69,8 +70,9 @@ def main():
|
|||||||
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
|
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images')
|
||||||
parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,448x448,384x384, etc, etc"', default="512x512,448x448,384x384")
|
parser.add_argument('--max_resolution', type=str, help='Maximum resolution(s) in the format "512x512,448x448,384x384, etc, etc"', default="512x512,448x448,384x384")
|
||||||
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
|
parser.add_argument('--divisible_by', type=int, help='Ensure new dimensions are divisible by this value', default=1)
|
||||||
|
parser.add_argument('--caption_extension', type=str, help='Extension of caption files to copy with resized images"', default=".txt")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution)
|
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution, args.divisible_by, args.caption_extension)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
@ -38,8 +38,13 @@ def train(args):
|
|||||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||||
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||||
|
|
||||||
if args.no_token_padding:
|
if args.no_token_padding:
|
||||||
train_dataset.disable_token_padding()
|
train_dataset.disable_token_padding()
|
||||||
|
|
||||||
|
# 学習データのdropout率を設定する
|
||||||
|
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
||||||
|
|
||||||
train_dataset.make_buckets()
|
train_dataset.make_buckets()
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
@ -203,6 +208,7 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
train_dataset.set_current_epoch(epoch + 1)
|
||||||
|
|
||||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||||
unet.train()
|
unet.train()
|
||||||
@ -327,7 +333,7 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, False)
|
train_util.add_dataset_arguments(parser, True, False, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
|
|
||||||
|
@ -120,18 +120,22 @@ def train(args):
|
|||||||
print("Use DreamBooth method.")
|
print("Use DreamBooth method.")
|
||||||
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
||||||
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
||||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||||
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
|
||||||
args.random_crop, args.debug_dataset)
|
args.random_crop, args.debug_dataset)
|
||||||
else:
|
else:
|
||||||
print("Train with captions.")
|
print("Train with captions.")
|
||||||
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||||
args.dataset_repeats, args.debug_dataset)
|
args.dataset_repeats, args.debug_dataset)
|
||||||
|
|
||||||
|
# 学習データのdropout率を設定する
|
||||||
|
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
||||||
|
|
||||||
train_dataset.make_buckets()
|
train_dataset.make_buckets()
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
@ -376,6 +380,8 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
train_dataset.set_current_epoch(epoch + 1)
|
||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch+1)
|
metadata["ss_epoch"] = str(epoch+1)
|
||||||
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
@ -509,7 +515,7 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, True)
|
train_util.add_dataset_arguments(parser, True, True, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
|
||||||
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||||
|
@ -235,7 +235,7 @@ def train(args):
|
|||||||
text_encoder, optimizer, train_dataloader, lr_scheduler)
|
text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||||
print(len(index_no_updates), torch.sum(index_no_updates))
|
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|
||||||
# Freeze all parameters except for the token embeddings in text encoder
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
@ -296,6 +296,7 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
train_dataset.set_current_epoch(epoch + 1)
|
||||||
|
|
||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
@ -383,8 +384,8 @@ def train(args):
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||||
d = updated_embs - bef_epo_embs
|
# d = updated_embs - bef_epo_embs
|
||||||
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
||||||
|
|
||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||||
@ -478,7 +479,7 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, True)
|
train_util.add_dataset_arguments(parser, True, True, False)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
|
||||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
|
Loading…
Reference in New Issue
Block a user