KohyaSS/library/config_util.py
2023-03-28 11:54:42 -04:00

537 lines
18 KiB
Python

import argparse
from dataclasses import (
asdict,
dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path
# from toolz import curry
from typing import (
List,
Optional,
Sequence,
Tuple,
Union,
)
import toml
import voluptuous
from voluptuous import (
Any,
ExactSequence,
MultipleInvalid,
Object,
Required,
Schema,
)
from transformers import CLIPTokenizer
from . import train_util
from .train_util import (
DreamBoothSubset,
FineTuningSubset,
DreamBoothDataset,
FineTuningDataset,
DatasetGroup,
)
def add_config_arguments(parser: argparse.ArgumentParser):
    """Register the ``--dataset_config`` option on ``parser``.

    The value is converted to a ``pathlib.Path``; when the option is omitted
    the parsed attribute is ``None``.
    """
    option_spec = {
        "type": Path,
        "default": None,
        "help": "config file for detail settings / 詳細な設定用の設定ファイル",
    }
    parser.add_argument("--dataset_config", **option_spec)
# TODO: inherit Params class in Subset, Dataset
@dataclass
class BaseSubsetParams:
    """Option defaults shared by every kind of subset.

    Subclassed by DreamBoothSubsetParams and FineTuningSubsetParams.
    BlueprintGenerator.generate_params_by_fallbacks instantiates the params
    class with no arguments to read these defaults, so every field must keep
    a default value.
    """

    image_dir: Optional[str] = None
    num_repeats: int = 1
    shuffle_caption: bool = False
    keep_tokens: int = 0
    color_aug: bool = False
    flip_aug: bool = False
    # The user-config schema only accepts exactly two floats for this option.
    face_crop_aug_range: Optional[Tuple[float, float]] = None
    random_crop: bool = False
    caption_dropout_rate: float = 0.0
    caption_dropout_every_n_epochs: int = 0
    caption_tag_dropout_rate: float = 0.0
    token_warmup_min: int = 1
    # NOTE(review): annotated float but defaulted with the int 0; the config
    # schema accepts either float or int for this option.
    token_warmup_step: float = 0
@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
    """Subset parameters for DreamBooth-style datasets.

    BlueprintGenerator.generate picks this class when no subset of the
    dataset carries a "metadata_file" key.
    """

    is_reg: bool = False
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"
@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    """Subset parameters for fine-tuning-style datasets.

    BlueprintGenerator.generate picks this class when at least one subset of
    the dataset carries a "metadata_file" key.
    """

    # The fine-tuning subset schema marks "metadata_file" as Required, even
    # though the dataclass default is None.
    metadata_file: Optional[str] = None
@dataclass
class BaseDatasetParams:
    """Option defaults shared by every kind of dataset.

    Subclassed by DreamBoothDatasetParams and FineTuningDatasetParams.
    """

    # NOTE(review): defaults to None although the annotation is not Optional.
    # No schema in this module lists "tokenizer"; presumably it is supplied
    # through BlueprintGenerator.generate's runtime_params -- confirm.
    tokenizer: CLIPTokenizer = None
    # NOTE(review): defaults to None although the annotation is not Optional;
    # the argparse schema accepts None or int for this option.
    max_token_length: int = None
    # The dataset schema converts a scalar int to a pair, or validates a
    # two-int sequence, before the value reaches this field.
    resolution: Optional[Tuple[int, int]] = None
    debug_dataset: bool = False
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
    """Dataset parameters for DreamBooth-style datasets.

    Declares the same batch/bucket fields as FineTuningDatasetParams, plus
    prior_loss_weight.
    """

    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
    # Only validated by the argparse schema, not by the user-config schemas.
    prior_loss_weight: float = 1.0
@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
    """Dataset parameters for fine-tuning-style datasets.

    Declares the same batch/bucket fields as DreamBoothDatasetParams, without
    prior_loss_weight.
    """

    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
@dataclass
class SubsetBlueprint:
    """Resolved parameters for one subset.

    generate_dataset_group_by_blueprint expands ``params`` with ``asdict`` and
    passes the result as keyword arguments to the subset class.
    """

    params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
@dataclass
class DatasetBlueprint:
    """Resolved parameters for one dataset together with its subsets."""

    # Selects the DreamBooth classes when True, the fine-tuning classes
    # otherwise (see generate_dataset_group_by_blueprint).
    is_dreambooth: bool
    params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
    subsets: Sequence[SubsetBlueprint]
@dataclass
class DatasetGroupBlueprint:
    """Ordered collection of dataset blueprints that form one dataset group."""

    datasets: Sequence[DatasetBlueprint]
@dataclass
class Blueprint:
    """Top-level result returned by BlueprintGenerator.generate."""

    dataset_group: DatasetGroupBlueprint
class ConfigSanitizer:
    """Build and apply voluptuous schemas for the dataset configuration.

    The class-level ``*_SCHEMA`` dictionaries are schema fragments. Fragments
    named ``ASCENDABLE`` are merged into the subset, dataset and "general"
    schemas alike; fragments named ``DISTINCT`` are merged into the subset
    schema only. ``__init__`` combines them according to the supported
    training modes and exposes two validators:

    * ``user_config_validator`` for the user supplied config dictionary
    * ``argparse_config_validator`` for the parsed command line namespace
    """

    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        """Validate that ``value`` is exactly two ``klass`` items and return it as a tuple.

        Raises:
            voluptuous.MultipleInvalid: if ``value`` does not match.
        """
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        """Accept a single ``klass`` value or a pair of them and return a 2-tuple.

        A scalar ``v`` becomes ``(v, v)``; a two-item sequence is returned as a
        tuple.

        Raises:
            voluptuous.MultipleInvalid: if ``value`` is neither form.
        """
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
        except voluptuous.Invalid:
            # Not a scalar, so it has to be the two-item form. Only validation
            # failures are handled here; anything else (e.g. KeyboardInterrupt)
            # must propagate instead of being swallowed.
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
        return (value, value)

    # subset schema
    SUBSET_ASCENDABLE_SCHEMA = {
        "color_aug": bool,
        "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
        "flip_aug": bool,
        "num_repeats": int,
        "random_crop": bool,
        "shuffle_caption": bool,
        "keep_tokens": int,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_dropout_every_n_epochs": int,
        "caption_dropout_rate": Any(float, int),
        "caption_tag_dropout_rate": Any(float, int),
    }
    # DB means DreamBooth
    DB_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "class_tokens": str,
    }
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
    }

    # datasets schema
    DATASET_ASCENDABLE_SCHEMA = {
        "batch_size": int,
        "bucket_no_upscale": bool,
        "bucket_reso_steps": int,
        "enable_bucket": bool,
        "max_bucket_reso": int,
        "min_bucket_reso": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
        "max_token_length": Any(None, int),
        "prior_loss_weight": Any(float, int),
    }
    # for handling default None value of argparse
    ARGPARSE_NULLABLE_OPTNAMES = [
        "face_crop_aug_range",
        "resolution",
    ]
    # prepare map because option name may differ among argparse and user config
    ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
        "train_batch_size": "batch_size",
        "dataset_repeats": "num_repeats",
    }

    def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
        """Assemble the validators for the requested feature set.

        Args:
            support_dreambooth: accept DreamBooth-style datasets/subsets.
            support_finetuning: accept fine-tuning-style datasets/subsets.
            support_dropout: accept the caption dropout options.

        Raises:
            AssertionError: if neither training mode is supported.
        """
        assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"

        dropout_schema = self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}

        self.db_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_DISTINCT_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            dropout_schema,
        )

        self.ft_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.FT_SUBSET_DISTINCT_SCHEMA,
            dropout_schema,
        )

        self.db_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            dropout_schema,
            {"subsets": [self.db_subset_schema]},
        )

        self.ft_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            dropout_schema,
            {"subsets": [self.ft_subset_schema]},
        )

        if support_dreambooth and support_finetuning:
            def validate_flex_dataset(dataset_config: dict):
                """Validate one dataset with whichever schema its subsets select."""
                subsets_config = dataset_config.get("subsets", [])
                # check dataset meets FT style
                # NOTE: all FT subsets should have "metadata_file"
                if all("metadata_file" in subset for subset in subsets_config):
                    return Schema(self.ft_dataset_schema)(dataset_config)
                # check dataset meets DB style
                # NOTE: all DB subsets should have no "metadata_file"
                elif all("metadata_file" not in subset for subset in subsets_config):
                    return Schema(self.db_dataset_schema)(dataset_config)
                else:
                    raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。")

            self.dataset_schema = validate_flex_dataset
        elif support_dreambooth:
            self.dataset_schema = self.db_dataset_schema
        else:
            self.dataset_schema = self.ft_dataset_schema

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
            dropout_schema,
        )

        self.user_config_validator = Schema({
            "general": self.general_schema,
            "datasets": [self.dataset_schema],
        })

        # The argparse schema starts from the general schema, then lets the
        # nullable options be None and adds the argparse-only option names.
        self.argparse_schema = self.__merge_dict(
            self.general_schema,
            self.ARGPARSE_SPECIFIC_SCHEMA,
            {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
            {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
        )

        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        """Validate ``user_config`` and return the validated dictionary.

        Raises:
            voluptuous.MultipleInvalid: re-raised after printing a short notice.
        """
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: make the message shown on error easier to understand
            print("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: In nature, argument parser result is not needed to be sanitize
    #   However this will help us to detect program bug
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        """Validate the parsed command line namespace and return the validated result.

        Raises:
            voluptuous.MultipleInvalid: re-raised after printing a short notice.
        """
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
            raise

    # NOTE: value would be overwritten by latter dict if there is already the same key
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        """Return a new dict holding every key of ``dict_list``; later dicts win."""
        merged = {}
        for schema in dict_list:
            merged.update(schema)
        return merged
class BlueprintGenerator:
    """Turn a user config plus parsed command line options into a Blueprint.

    Each parameter is looked up in a chain of sources, most specific first;
    the first value that is not ``None`` is used, and the params dataclass
    default applies when no source provides one.
    """

    # Params field name -> config option name, for fields whose names differ.
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        """Validate both inputs and resolve the parameters of every dataset and subset."""
        clean_config = self.sanitizer.sanitize_user_config(user_config)
        clean_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # Re-key the argparse values with the option names used in the config.
        # NOTE: it is ok to have extra entries in dict
        renames = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
        argparse_config = {}
        for optname, value in vars(clean_namespace).items():
            argparse_config[renames.get(optname, optname)] = value

        general_config = clean_config.get("general", {})
        resolve = self.generate_params_by_fallbacks

        dataset_blueprints = []
        for dataset_config in clean_config.get("datasets", []):
            subsets = dataset_config.get("subsets", [])
            # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
            is_dreambooth = not any("metadata_file" in subset for subset in subsets)
            if is_dreambooth:
                subset_params_klass, dataset_params_klass = DreamBoothSubsetParams, DreamBoothDatasetParams
            else:
                subset_params_klass, dataset_params_klass = FineTuningSubsetParams, FineTuningDatasetParams

            # Lookup order for dataset-level values; a subset puts its own
            # config in front of this chain.
            dataset_sources = [dataset_config, general_config, argparse_config, runtime_params]

            subset_blueprints = [
                SubsetBlueprint(resolve(subset_params_klass, [subset_config, *dataset_sources]))
                for subset_config in subsets
            ]
            dataset_params = resolve(dataset_params_klass, dataset_sources)
            dataset_blueprints.append(DatasetBlueprint(is_dreambooth, dataset_params, subset_blueprints))

        return Blueprint(DatasetGroupBlueprint(dataset_blueprints))

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        """Instantiate ``param_klass`` with every field resolved through ``fallbacks``."""
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        resolved = {}
        # A default-constructed instance supplies both the field names and
        # the value to use when no source provides one.
        for name, default in asdict(param_klass()).items():
            resolved[name] = BlueprintGenerator.search_value(name_map.get(name, name), fallbacks, default)
        return param_klass(**resolved)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value = None):
        """Return the first non-``None`` value for ``key`` in ``fallbacks``, else ``default_value``."""
        candidates = (source.get(key) for source in fallbacks)
        return next((value for value in candidates if value is not None), default_value)
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> DatasetGroup:
    """Instantiate every dataset described by the blueprint and wrap them in a DatasetGroup.

    For each dataset blueprint the matching subset/dataset classes are chosen
    from ``is_dreambooth``; the params dataclasses are expanded with ``asdict``
    and passed as keyword arguments. A summary of all datasets and subsets is
    printed, then ``make_buckets()`` and ``set_seed()`` are called on each
    dataset with one shared random seed.
    """
    datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        if dataset_blueprint.is_dreambooth:
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
        else:
            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
        datasets.append(dataset)

    # Build one text report covering all datasets, printed once below.
    # NOTE(review): the template text and the indent() prefixes are kept
    # exactly as found; their leading whitespace looks to have been lost, so
    # confirm the intended layout before changing them.
    info = ""
    for i, dataset in enumerate(datasets):
        is_dreambooth = isinstance(dataset, DreamBoothDataset)
        info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

        # Bucket details are only reported when bucketing is enabled.
        if dataset.enable_bucket:
            info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
        else:
            info += "\n"

        for j, subset in enumerate(dataset.subsets):
            info += indent(dedent(f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ")

            # The trailing fields differ between DreamBooth and fine-tuning subsets.
            if is_dreambooth:
                info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
            else:
                info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")

    print(info)

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        print(f"[Dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return DatasetGroup(datasets)
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
    """Build DreamBooth subset configs from ``<repeats>_<class tokens>`` sub-directories.

    Every direct sub-directory of ``train_data_dir`` and ``reg_data_dir`` whose
    name starts with an integer repeat count of at least 1, followed by ``_``,
    yields one subset config. Other entries are skipped.

    Args:
        train_data_dir: directory holding the training sub-directories, or None.
        reg_data_dir: directory holding the regularization sub-directories, or None.

    Returns:
        A list of dicts with the keys ``image_dir``, ``num_repeats``, ``is_reg``
        and ``class_tokens``; training entries come first, then regularization
        entries. A base directory that is None or not a directory contributes
        nothing.
    """
    def extract_dreambooth_params(name: str) -> Tuple[int, str]:
        """Split a directory name into (repeat count, class tokens); (0, "") if it has no count."""
        tokens = name.split('_')
        try:
            n_repeats = int(tokens[0])
        except ValueError:
            # Report the offending directory name (the message used to
            # interpolate the builtin ``dir`` by mistake).
            print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
            return 0, ""
        caption_by_folder = '_'.join(tokens[1:])
        return n_repeats, caption_by_folder

    def generate(base_dir: Optional[str], is_reg: bool):
        """Collect the subset configs found directly under ``base_dir``."""
        if base_dir is None:
            return []

        base_dir: Path = Path(base_dir)
        if not base_dir.is_dir():
            return []

        subsets_config = []
        for subdir in base_dir.iterdir():
            if not subdir.is_dir():
                continue

            num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
            # A repeat count below 1 (including the "no count" marker 0) disables the directory.
            if num_repeats < 1:
                continue

            subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
            subsets_config.append(subset_config)

        return subsets_config

    subsets_config = []
    subsets_config += generate(train_data_dir, False)
    subsets_config += generate(reg_data_dir, True)

    return subsets_config
def load_user_config(file: str) -> dict:
    """Load a dataset config file and return its parsed content.

    The format is chosen from the file name, case-insensitively: ``.json`` is
    parsed with ``json`` and ``.toml`` with ``toml``.

    Args:
        file: path of the config file.

    Returns:
        The parsed configuration.

    Raises:
        ValueError: if ``file`` is not an existing file or has an unsupported
            extension.
        Exception: any error raised while reading or parsing is re-raised
            after a short notice has been printed.
    """
    config_path = Path(file)
    if not config_path.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {config_path}")

    lowered_name = config_path.name.lower()
    if lowered_name.endswith('.json'):
        try:
            # json.load needs an object with .read(); a Path has none, so the
            # file must be opened first. JSON text is read as UTF-8.
            with open(config_path, encoding="utf-8") as json_file:
                config = json.load(json_file)
        except Exception:
            print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {config_path}")
            raise
    elif lowered_name.endswith('.toml'):
        try:
            config = toml.load(config_path)
        except Exception:
            print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {config_path}")
            raise
    else:
        raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {config_path}")

    return config
# for config test
if __name__ == "__main__":
    # First pass: read the switches owned by this self-test and leave every
    # other command line token for the training argument parser.
    parser = argparse.ArgumentParser()
    for switch in ("--support_dreambooth", "--support_finetuning", "--support_dropout"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("dataset_config")
    config_args, remain = parser.parse_known_args()

    with_dreambooth = config_args.support_dreambooth
    with_finetuning = config_args.support_finetuning
    with_dropout = config_args.support_dropout

    # Second pass: parse the leftover tokens with the dataset/training options.
    parser = argparse.ArgumentParser()
    train_util.add_dataset_arguments(parser, with_dreambooth, with_finetuning, with_dropout)
    train_util.add_training_arguments(parser, with_dreambooth)
    argparse_namespace = parser.parse_args(remain)
    train_util.prepare_dataset_args(argparse_namespace, with_finetuning)

    print("[argparse_namespace]", vars(argparse_namespace), sep="\n")

    user_config = load_user_config(config_args.dataset_config)
    print("\n[user_config]", user_config, sep="\n")

    sanitizer = ConfigSanitizer(with_dreambooth, with_finetuning, with_dropout)
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)
    print("\n[sanitized_user_config]", sanitized_user_config, sep="\n")

    # generate() sanitizes its inputs itself, so it is given the raw config.
    generator = BlueprintGenerator(sanitizer)
    blueprint = generator.generate(user_config, argparse_namespace)
    print("\n[blueprint]", blueprint, sep="\n")