import argparse from dataclasses import ( asdict, dataclass, ) import functools from textwrap import dedent, indent import json from pathlib import Path # from toolz import curry from typing import ( List, Optional, Sequence, Tuple, Union, ) import toml import voluptuous from voluptuous import ( Any, ExactSequence, MultipleInvalid, Object, Required, Schema, ) from transformers import CLIPTokenizer from . import train_util from .train_util import ( DreamBoothSubset, FineTuningSubset, DreamBoothDataset, FineTuningDataset, DatasetGroup, ) def add_config_arguments(parser: argparse.ArgumentParser): parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル") # TODO: inherit Params class in Subset, Dataset @dataclass class BaseSubsetParams: image_dir: Optional[str] = None num_repeats: int = 1 shuffle_caption: bool = False keep_tokens: int = 0 color_aug: bool = False flip_aug: bool = False face_crop_aug_range: Optional[Tuple[float, float]] = None random_crop: bool = False caption_dropout_rate: float = 0.0 caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 @dataclass class DreamBoothSubsetParams(BaseSubsetParams): is_reg: bool = False class_tokens: Optional[str] = None caption_extension: str = ".caption" @dataclass class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None @dataclass class BaseDatasetParams: tokenizer: CLIPTokenizer = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None debug_dataset: bool = False @dataclass class DreamBoothDatasetParams(BaseDatasetParams): batch_size: int = 1 enable_bucket: bool = False min_bucket_reso: int = 256 max_bucket_reso: int = 1024 bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 enable_bucket: bool = False min_bucket_reso: int = 256 max_bucket_reso: int = 1024 bucket_reso_steps: int = 64 bucket_no_upscale: bool = False @dataclass class SubsetBlueprint: params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] @dataclass class DatasetBlueprint: is_dreambooth: bool params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] subsets: Sequence[SubsetBlueprint] @dataclass class DatasetGroupBlueprint: datasets: Sequence[DatasetBlueprint] @dataclass class Blueprint: dataset_group: DatasetGroupBlueprint class ConfigSanitizer: # @curry @staticmethod def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: Schema(ExactSequence([klass, klass]))(value) return tuple(value) # @curry @staticmethod def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: Schema(Any(klass, ExactSequence([klass, klass])))(value) try: Schema(klass)(value) return (value, value) except: return ConfigSanitizer.__validate_and_convert_twodim(klass, value) # subset schema SUBSET_ASCENDABLE_SCHEMA = { "color_aug": bool, "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), "flip_aug": bool, "num_repeats": int, "random_crop": bool, "shuffle_caption": bool, "keep_tokens": int, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { "caption_dropout_every_n_epochs": int, "caption_dropout_rate": Any(float, int), "caption_tag_dropout_rate": Any(float, int), } # DB means DreamBooth DB_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, "class_tokens": str, } DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, "is_reg": bool, } # FT means FineTuning FT_SUBSET_DISTINCT_SCHEMA = { Required("metadata_file"): str, "image_dir": str, } # datasets schema DATASET_ASCENDABLE_SCHEMA = { "batch_size": int, "bucket_no_upscale": bool, "bucket_reso_steps": int, "enable_bucket": bool, "max_bucket_reso": int, "min_bucket_reso": int, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } # options handled by argparse but not handled by user config ARGPARSE_SPECIFIC_SCHEMA = { "debug_dataset": bool, "max_token_length": Any(None, int), "prior_loss_weight": Any(float, int), } # for handling default None value of argparse ARGPARSE_NULLABLE_OPTNAMES = [ "face_crop_aug_range", "resolution", ] # prepare map because option name may differ among argparse and user config ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { "train_batch_size": "batch_size", "dataset_repeats": "num_repeats", } def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None: assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。" self.db_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, self.DB_SUBSET_DISTINCT_SCHEMA, self.DB_SUBSET_ASCENDABLE_SCHEMA, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) self.ft_subset_schema = self.__merge_dict( self.SUBSET_ASCENDABLE_SCHEMA, self.FT_SUBSET_DISTINCT_SCHEMA, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) self.db_dataset_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, self.DB_SUBSET_ASCENDABLE_SCHEMA, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, {"subsets": [self.db_subset_schema]}, ) self.ft_dataset_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, {"subsets": [self.ft_subset_schema]}, ) if support_dreambooth and support_finetuning: def validate_flex_dataset(dataset_config: dict): subsets_config = dataset_config.get("subsets", []) # check dataset meets FT style # NOTE: all FT subsets should have "metadata_file" if all(["metadata_file" in subset for subset in subsets_config]): return Schema(self.ft_dataset_schema)(dataset_config) # check dataset meets DB style # NOTE: all DB subsets should have no "metadata_file" elif all(["metadata_file" not in subset for subset in subsets_config]): return Schema(self.db_dataset_schema)(dataset_config) else: raise voluptuous.Invalid("DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。") self.dataset_schema = validate_flex_dataset elif support_dreambooth: self.dataset_schema = self.db_dataset_schema else: self.dataset_schema = self.ft_dataset_schema self.general_schema = self.__merge_dict( self.DATASET_ASCENDABLE_SCHEMA, self.SUBSET_ASCENDABLE_SCHEMA, self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, ) self.user_config_validator = Schema({ "general": self.general_schema, "datasets": [self.dataset_schema], }) self.argparse_schema = self.__merge_dict( self.general_schema, self.ARGPARSE_SPECIFIC_SCHEMA, {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, ) self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) def sanitize_user_config(self, user_config: dict) -> dict: try: return self.user_config_validator(user_config) except MultipleInvalid: # TODO: エラー発生時のメッセージをわかりやすくする print("Invalid user config / ユーザ設定の形式が正しくないようです") raise # NOTE: In nature, argument parser result is not needed to be sanitize # However this will help us to detect program bug def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: try: return self.argparse_config_validator(argparse_namespace) except MultipleInvalid: # XXX: this should be a bug print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。") raise # NOTE: value would be overwritten by latter dict if there is already the same key @staticmethod def __merge_dict(*dict_list: dict) -> dict: merged = {} for schema in dict_list: # merged |= schema for k, v in schema.items(): merged[k] = v return merged class BlueprintGenerator: BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = { } def __init__(self, sanitizer: ConfigSanitizer): self.sanitizer = sanitizer # runtime_params is for parameters which is only configurable on runtime, such as tokenizer def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) # convert argparse namespace to dict like config # NOTE: it is ok to have extra entries in dict optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME argparse_config = {optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()} general_config = sanitized_user_config.get("general", {}) dataset_blueprints = [] for dataset_config in sanitized_user_config.get("datasets", []): # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets subsets = dataset_config.get("subsets", []) is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) if is_dreambooth: subset_params_klass = DreamBoothSubsetParams dataset_params_klass = DreamBoothDatasetParams else: subset_params_klass = FineTuningSubsetParams dataset_params_klass = FineTuningDatasetParams subset_blueprints = [] for subset_config in subsets: params = self.generate_params_by_fallbacks(subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]) subset_blueprints.append(SubsetBlueprint(params)) params = self.generate_params_by_fallbacks(dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]) dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints)) dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) return Blueprint(dataset_group_blueprint) @staticmethod def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME search_value = BlueprintGenerator.search_value default_params = asdict(param_klass()) param_names = default_params.keys() params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} return param_klass(**params) @staticmethod def search_value(key: str, fallbacks: Sequence[dict], default_value = None): for cand in fallbacks: value = cand.get(key) if value is not None: return value return default_value def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = [] for dataset_blueprint in dataset_group_blueprint.datasets: if dataset_blueprint.is_dreambooth: subset_klass = DreamBoothSubset dataset_klass = DreamBoothDataset else: subset_klass = FineTuningSubset dataset_klass = FineTuningDataset subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) datasets.append(dataset) # print info info = "" for i, dataset in enumerate(datasets): is_dreambooth = isinstance(dataset, DreamBoothDataset) info += dedent(f"""\ [Dataset {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} """) if dataset.enable_bucket: info += indent(dedent(f"""\ min_bucket_reso: {dataset.min_bucket_reso} max_bucket_reso: {dataset.max_bucket_reso} bucket_reso_steps: {dataset.bucket_reso_steps} bucket_no_upscale: {dataset.bucket_no_upscale} \n"""), " ") else: info += "\n" for j, subset in enumerate(dataset.subsets): info += indent(dedent(f"""\ [Subset {j} of Dataset {i}] image_dir: "{subset.image_dir}" image_count: {subset.img_count} num_repeats: {subset.num_repeats} shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} caption_dropout_rate: {subset.caption_dropout_rate} caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} color_aug: {subset.color_aug} flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} """), " ") if is_dreambooth: info += indent(dedent(f"""\ is_reg: {subset.is_reg} class_tokens: {subset.class_tokens} caption_extension: {subset.caption_extension} \n"""), " ") else: info += indent(dedent(f"""\ metadata_file: {subset.metadata_file} \n"""), " ") print(info) # make buckets first because it determines the length of dataset for i, dataset in enumerate(datasets): print(f"[Dataset {i}]") dataset.make_buckets() return DatasetGroup(datasets) def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): def extract_dreambooth_params(name: str) -> Tuple[int, str]: tokens = name.split('_') try: n_repeats = int(tokens[0]) except ValueError as e: print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") return 0, "" caption_by_folder = '_'.join(tokens[1:]) return n_repeats, caption_by_folder def generate(base_dir: Optional[str], is_reg: bool): if base_dir is None: return [] base_dir: Path = Path(base_dir) if not base_dir.is_dir(): return [] subsets_config = [] for subdir in base_dir.iterdir(): if not subdir.is_dir(): continue num_repeats, class_tokens = extract_dreambooth_params(subdir.name) if num_repeats < 1: continue subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} subsets_config.append(subset_config) return subsets_config subsets_config = [] subsets_config += generate(train_data_dir, False) subsets_config += generate(reg_data_dir, True) return subsets_config def load_user_config(file: str) -> dict: file: Path = Path(file) if not file.is_file(): raise ValueError(f"file not found / ファイルが見つかりません: {file}") if file.name.lower().endswith('.json'): try: config = json.load(file) except Exception: print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise elif file.name.lower().endswith('.toml'): try: config = toml.load(file) except Exception: print(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}") raise else: raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}") return config # for config test if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--support_dreambooth", action="store_true") parser.add_argument("--support_finetuning", action="store_true") parser.add_argument("--support_dropout", action="store_true") parser.add_argument("dataset_config") config_args, remain = parser.parse_known_args() parser = argparse.ArgumentParser() train_util.add_dataset_arguments(parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) train_util.add_training_arguments(parser, config_args.support_dreambooth) argparse_namespace = parser.parse_args(remain) train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) print("[argparse_namespace]") print(vars(argparse_namespace)) user_config = load_user_config(config_args.dataset_config) print("\n[user_config]") print(user_config) sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) sanitized_user_config = sanitizer.sanitize_user_config(user_config) print("\n[sanitized_user_config]") print(sanitized_user_config) blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) print("\n[blueprint]") print(blueprint)