2023/02/09 (v20.7.1)

- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
        - ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).
        - ``--caption_dropout_every_n_epochs`` option specifies how many epochs to drop captions. If ``3`` is specified, in epoch 3, 6, 9 ..., images are trained with all captions empty. Default is None (no dropout).
        - ``--caption_tag_dropout_rate`` option specified the dropout rate for tags (comma separated tokens) (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the tag is removed from the caption. If ``--keep_tokens`` option is set, these tokens (tags) are not dropped. Default is 0 (no droupout).
        - The bulk image downsampling script is added. Documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) (in Jpanaese). Thanks to bmaltais!
        - Typo check is added. Thanks to shirayu!
    - Add option to autolaunch the GUI in a browser and set the server_port. USe either `gui.ps1 --inbrowser --server_port 3456`or `gui.cmd -inbrowser -server_port 3456`
This commit is contained in:
bmaltais 2023-02-09 19:17:24 -05:00
parent 90c0d55457
commit 7bc93821a0
10 changed files with 230 additions and 74 deletions

View File

@ -38,7 +38,7 @@ def train(args):
args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
@ -230,8 +230,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
train_dataset.set_current_epoch(epoch + 1)
for m in training_models:
m.train()

23
gui.bat
View File

@ -1,10 +1,23 @@
@echo off
set VENV_DIR=.\venv
set PYTHON=python
REM Use this batch file with the following options:
REM -inbrowser - To launch the program in the browser
REM -server_port [port number] - To specify the server port
call %VENV_DIR%\Scripts\activate.bat
set inbrowserOption=
set serverPortOption=
%PYTHON% kohya_gui.py
if "%1" == "-server_port" (
set serverPortOption=--server_port %2
if "%3" == "-inbrowser" (
set inbrowserOption=--inbrowser
)
) else if "%1" == "-inbrowser" (
set inbrowserOption=--inbrowser
if "%2" == "-server_port" (
set serverPortOption=--server_port %3
)
)
pause
call .\venv\Scripts\activate.bat
python.exe kohya_gui.py %inbrowserOption% %serverPortOption%

11
gui.ps1
View File

@ -1,2 +1,11 @@
# Example command: .\gui.ps1 -server_port 8000 -inbrowser
param([string]$username="", [string]$password="", [switch]$inbrowser, [int]$server_port)
.\venv\Scripts\activate
python.exe kohya_gui.py
if ($server_port -le 0 -and $inbrowser -eq $false) {
Write-Host "Error: You must provide either the --server_port or --inbrowser argument."
exit 1
}
python.exe kohya_gui.py --username $username --password $password --server_port $server_port --inbrowser

View File

@ -10,7 +10,7 @@ from library.merge_lora_gui import gradio_merge_lora_tab
from lora_gui import lora_tab
def UI(username, password):
def UI(username, password, inbrowser, server_port):
css = ''
@ -47,11 +47,13 @@ def UI(username, password):
gradio_merge_lora_tab()
# Show the interface
if not username == '':
interface.launch(auth=(username, password))
else:
interface.launch()
kwargs = {}
if username:
kwargs["auth"] = (username, password)
if server_port > 0:
kwargs["server_port"] = server_port
kwargs["inbrowser"] = inbrowser
interface.launch(**kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
@ -62,7 +64,11 @@ if __name__ == '__main__':
parser.add_argument(
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password)
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)

View File

@ -223,8 +223,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
# TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう
self.epoch_current: int = int(0)
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
@ -252,11 +251,14 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
# 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく
def set_current_epoch(self, epoch):
self.current_epoch = epoch
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
# コンストラクタで渡さないのはTextual Inversionで意識したくないからということにしておく
self.dropout_rate = dropout_rate
self.dropout_every_n_epochs = dropout_every_n_epochs
self.tag_dropout_rate = tag_dropout_rate
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
@ -275,27 +277,47 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements[str_from] = str_to
def process_caption(self, caption):
if self.shuffle_caption:
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
random.shuffle(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
tokens = keep_tokens + tokens
caption = ", ".join(tokens)
# dropoutの決定tag dropがこのメソッド内にあるのでここで行うのが良い
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
if is_drop_out:
caption = ""
else:
if self.shuffle_caption:
def dropout_tags(tokens):
if self.tag_dropout_rate <= 0:
return tokens
l = []
for token in tokens:
if random.random() >= self.tag_dropout_rate:
l.append(token)
return l
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
caption = ", ".join(tokens)
# textual inversion対応
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
return caption
@ -609,18 +631,7 @@ class BaseDataset(torch.utils.data.Dataset):
images.append(image)
latents_list.append(latents)
# dropoutの決定
is_drop_out = False
if self.dropout_rate > 0 and random.random() < self.dropout_rate:
is_drop_out = True
if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
is_drop_out = True
if is_drop_out:
caption = ""
print(f"Drop caption out: {self.process_caption(image_info.caption)}")
else:
caption = self.process_caption(image_info.caption)
caption = self.process_caption(image_info.caption)
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
@ -929,6 +940,8 @@ class FineTuningDataset(BaseDataset):
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
train_dataset.set_current_epoch(1)
k = 0
for i, example in enumerate(train_dataset):
if example['latents'] is not None:
@ -1437,6 +1450,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
if support_dreambooth:
# DreamBooth dataset

View File

@ -1,24 +1,26 @@
accelerate==0.15.0
transformers==4.26.0
ftfy
albumentations
opencv-python
einops
ftfy==6.1.1
albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
diffusers[torch]==0.10.2
pytorch_lightning
pytorch-lightning==1.9.0
bitsandbytes==0.35.0
tensorboard
tensorboard==2.10.1
safetensors==0.2.6
gradio==3.16.2
altair
easygui
tk
altair==4.2.2
easygui==0.98.3
tk==0.1.0
# for BLIP captioning
requests
timm
fairscale
requests==2.28.2
timm==0.6.12
fairscale==0.4.13
# for WD14 captioning
tensorflow<2.11
huggingface-hub
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0
xformers @ https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
# for kohya_ss library
.

View File

@ -0,0 +1,113 @@
import glob
import os
import cv2
import argparse
import shutil
import math
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')]
# # Calculate max_pixels from max_resolution string
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp etc...
if not filename.endswith(img_exts):
# Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
img = cv2.imread(os.path.join(src_img_folder, filename))
base, _ = os.path.splitext(filename)
for max_resolution in max_resolutions:
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
new_height, new_width = img.shape[0:2]
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Split filename into base and extension
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
# Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
for asoc_file in asoc_files:
ext = os.path.splitext(asoc_file)[1]
if ext in img_exts:
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
def main():
parser = argparse.ArgumentParser(
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
parser.add_argument('--max_resolution', type=str,
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
if __name__ == '__main__':
main()

View File

@ -43,7 +43,7 @@ def train(args):
train_dataset.disable_token_padding()
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
@ -208,8 +208,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
train_dataset.set_current_epoch(epoch + 1)
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()

View File

@ -134,7 +134,7 @@ def train(args):
args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs)
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
@ -380,8 +380,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.epoch_current = epoch + 1
train_dataset.set_current_epoch(epoch + 1)
metadata["ss_epoch"] = str(epoch+1)

View File

@ -235,7 +235,7 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
print(len(index_no_updates), torch.sum(index_no_updates))
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
@ -296,6 +296,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
text_encoder.train()
@ -383,8 +384,8 @@ def train(args):
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
d = updated_embs - bef_epo_embs
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
# d = updated_embs - bef_epo_embs
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name