Update to latest sd-scripts updates
This commit is contained in:
parent
ccae80186a
commit
1c8d901c3b
@ -205,6 +205,13 @@ This will store your a backup file with your current locally installed pip packa
|
||||
- `( )`, `(xxxx:1.2)` and `[ ]` can be used.
|
||||
- Fix exception on training model in diffusers format with `train_network.py` Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290)
|
||||
- Add warning if you are about to overwrite an existing model: https://github.com/bmaltais/kohya_ss/issues/404
|
||||
- Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls.
|
||||
- Please start with`2` or `4` depending on the size of VRAM.
|
||||
- Fix a number of training steps with `--gradient_accumulation_steps` and `--max_train_epochs`. Thanks to tsukimiya!
|
||||
- Extract parser setup to external scripts. Thanks to robertsmieja!
|
||||
- Fix an issue without `.npz` and with `--full_path` in training.
|
||||
- Support extensions with upper cases for images for not Windows environment.
|
||||
- Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
|
||||
* 2023/03/19 (v21.2.5):
|
||||
- Fix basic captioning logic
|
||||
- Add possibility to not train TE in Dreamboot by setting `Step text encoder training` to -1.
|
||||
|
@ -107,6 +107,7 @@ def save_configuration(
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -214,6 +215,7 @@ def open_configuration(
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -303,6 +305,7 @@ def train_model(
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
):
|
||||
if pretrained_model_name_or_path == '':
|
||||
msgbox('Source model information is missing')
|
||||
@ -480,6 +483,7 @@ def train_model(
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
noise_offset=noise_offset,
|
||||
additional_parameters=additional_parameters,
|
||||
vae_batch_size=vae_batch_size,
|
||||
)
|
||||
|
||||
run_cmd += run_cmd_sample(
|
||||
@ -686,6 +690,7 @@ def dreambooth_tab(
|
||||
caption_dropout_rate,
|
||||
noise_offset,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
) = gradio_advanced_training()
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
@ -786,6 +791,7 @@ def dreambooth_tab(
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
|
14
fine_tune.py
14
fine_tune.py
@ -138,7 +138,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
@ -194,7 +194,7 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# lr schedulerを用意する
|
||||
@ -240,7 +240,7 @@ def train(args):
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
@ -387,7 +387,7 @@ def train(args):
|
||||
print("model saved.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
@ -400,6 +400,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
|
@ -163,13 +163,19 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
if len(unknown) == 1:
|
||||
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
||||
|
@ -133,7 +133,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
||||
@ -153,6 +153,12 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
|
@ -127,7 +127,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
@ -141,5 +141,11 @@ if __name__ == '__main__':
|
||||
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -46,7 +46,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
@ -61,6 +61,12 @@ if __name__ == '__main__':
|
||||
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
|
@ -47,7 +47,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
@ -61,5 +61,11 @@ if __name__ == '__main__':
|
||||
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -229,7 +229,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
@ -257,5 +257,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--skip_existing", action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -173,7 +173,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
|
||||
@ -191,6 +191,12 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
|
@ -104,7 +104,7 @@ def save_configuration(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -217,7 +217,7 @@ def open_configuration(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -312,7 +312,7 @@ def train_model(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
if check_if_model_exist(output_name, output_dir, save_model_as):
|
||||
return
|
||||
@ -470,6 +470,7 @@ def train_model(
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
noise_offset=noise_offset,
|
||||
additional_parameters=additional_parameters,
|
||||
vae_batch_size=vae_batch_size,
|
||||
)
|
||||
|
||||
run_cmd += run_cmd_sample(
|
||||
@ -686,6 +687,7 @@ def finetune_tab():
|
||||
caption_dropout_rate,
|
||||
noise_offset,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
) = gradio_advanced_training()
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
@ -780,6 +782,7 @@ def finetune_tab():
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
]
|
||||
|
||||
button_run.click(train_model, inputs=settings_list)
|
||||
|
@ -2690,7 +2690,7 @@ def main(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--v2", action='store_true', help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
||||
@ -2786,5 +2786,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
|
||||
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -928,6 +928,12 @@ def gradio_advanced_training():
|
||||
caption_dropout_rate = gr.Slider(
|
||||
label='Rate of caption dropout', value=0, minimum=0, maximum=1
|
||||
)
|
||||
vae_batch_size = gr.Slider(
|
||||
label='VAE batch size',
|
||||
minimum=0,
|
||||
maximum=32,
|
||||
value=0
|
||||
)
|
||||
with gr.Row():
|
||||
save_state = gr.Checkbox(label='Save training state', value=False)
|
||||
resume = gr.Textbox(
|
||||
@ -972,6 +978,7 @@ def gradio_advanced_training():
|
||||
caption_dropout_rate,
|
||||
noise_offset,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
)
|
||||
|
||||
|
||||
@ -998,8 +1005,11 @@ def run_cmd_advanced_training(**kwargs):
|
||||
f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"'
|
||||
if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
|
||||
else '',
|
||||
f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"'
|
||||
if float(kwargs.get('caption_dropout_rate', 0)) > 0
|
||||
f' --caption_dropout_every_n_epochs="{int(kwargs.get("caption_dropout_every_n_epochs", 0))}"'
|
||||
if int(kwargs.get('caption_dropout_every_n_epochs', 0)) > 0
|
||||
else '',
|
||||
f' --vae_batch_size="{kwargs.get("vae_batch_size", 0)}"'
|
||||
if int(kwargs.get('vae_batch_size', 0)) > 0
|
||||
else '',
|
||||
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
|
||||
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
|
||||
|
@ -73,8 +73,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
|
||||
|
||||
# region dataset
|
||||
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
||||
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
@ -675,10 +674,19 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def cache_latents(self, vae):
|
||||
# TODO ここを高速化したい
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
# ちょっと速くした
|
||||
print("caching latents.")
|
||||
for info in tqdm(self.image_data.values()):
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
for info in image_infos:
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None:
|
||||
@ -689,18 +697,42 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||
continue
|
||||
|
||||
image = self.load_image(info.absolute_path)
|
||||
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
img_tensor = self.image_transforms(image)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
||||
batch.append(info)
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
# iterate batches
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
images = []
|
||||
for info in batch:
|
||||
image = self.load_image(info.absolute_path)
|
||||
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
||||
image = self.image_transforms(image)
|
||||
images.append(image)
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
for info, latent in zip(batch, latents):
|
||||
info.latents = latent
|
||||
|
||||
if subset.flip_aug:
|
||||
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
||||
img_tensor = self.image_transforms(image)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
for info, latent in zip(batch, latents):
|
||||
info.latents_flipped = latent
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
image = Image.open(image_path)
|
||||
@ -1197,6 +1229,10 @@ class FineTuningDataset(BaseDataset):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
# if not full path, check image_dir. if image_dir is None, return None
|
||||
if subset.image_dir is None:
|
||||
return None, None
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
@ -1237,10 +1273,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
# for dataset in self.datasets:
|
||||
# dataset.make_buckets()
|
||||
|
||||
def cache_latents(self, vae):
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
print(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae)
|
||||
dataset.cache_latents(vae, vae_batch_size)
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
@ -1986,6 +2022,7 @@ def add_dataset_arguments(
|
||||
action="store_true",
|
||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
|
||||
)
|
||||
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
|
||||
)
|
||||
|
@ -123,7 +123,7 @@ def save_configuration(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -240,7 +240,7 @@ def open_configuration(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -347,7 +347,7 @@ def train_model(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
print_only_bool = True if print_only.get('label') == 'True' else False
|
||||
|
||||
@ -589,6 +589,7 @@ def train_model(
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
noise_offset=noise_offset,
|
||||
additional_parameters=additional_parameters,
|
||||
vae_batch_size=vae_batch_size,
|
||||
)
|
||||
|
||||
run_cmd += run_cmd_sample(
|
||||
@ -891,6 +892,7 @@ def lora_tab(
|
||||
caption_dropout_rate,
|
||||
noise_offset,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
) = gradio_advanced_training()
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
@ -1008,6 +1010,7 @@ def lora_tab(
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
|
@ -24,9 +24,16 @@ def main(file):
|
||||
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.file)
|
||||
|
@ -113,7 +113,7 @@ def svd(args):
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat.to("cuda"))
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
@ -122,18 +122,18 @@ def svd(args):
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
# hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
# low_val = -hi_val
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
# U = U.clamp(low_val, hi_val)
|
||||
# Vh = Vh.clamp(low_val, hi_val)
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
U = U.to("cuda").contiguous()
|
||||
Vh = Vh.to("cuda").contiguous()
|
||||
U = U.to("cpu").contiguous()
|
||||
Vh = Vh.to("cpu").contiguous()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
@ -162,7 +162,7 @@ def svd(args):
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
@ -179,5 +179,11 @@ if __name__ == '__main__':
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(args)
|
||||
|
@ -105,7 +105,7 @@ def interrogate(args):
|
||||
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
@ -118,5 +118,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--clip_skip", type=int, default=None,
|
||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
interrogate(args)
|
||||
|
@ -197,7 +197,7 @@ def merge(args):
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
@ -214,5 +214,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
|
@ -158,7 +158,7 @@ def merge(args):
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
@ -175,5 +175,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
|
@ -208,18 +208,28 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if 'lora_down' in key:
|
||||
block_down_name = key.split(".")[0]
|
||||
weight_name = key.split(".")[-1]
|
||||
lora_down_weight = value
|
||||
if 'lora_up' in key:
|
||||
block_up_name = key.split(".")[0]
|
||||
lora_up_weight = value
|
||||
else:
|
||||
continue
|
||||
|
||||
# find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
|
||||
|
||||
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
||||
|
||||
if (block_down_name == block_up_name) and weights_loaded:
|
||||
if weights_loaded:
|
||||
|
||||
conv2d = (len(lora_down_weight.size()) == 4)
|
||||
if lora_alpha is None:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = lora_alpha/lora_down_weight.size()[0]
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
@ -311,7 +321,7 @@ def resize(args):
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
@ -329,7 +339,12 @@ if __name__ == '__main__':
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
||||
parser.add_argument("--dynamic_param", type=float, default=None,
|
||||
help="Specify target for dynamic reduction")
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
|
@ -76,7 +76,11 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
down_weight = down_weight.to(device)
|
||||
|
||||
# W <- W + U * D
|
||||
scale = (alpha / network_dim).to(device)
|
||||
scale = (alpha / network_dim)
|
||||
|
||||
if device: # and isinstance(scale, torch.Tensor):
|
||||
scale = scale.to(device)
|
||||
|
||||
if not conv2d: # linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
elif kernel_size == (1, 1):
|
||||
@ -115,12 +119,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
|
||||
Vh = Vh[:module_new_rank, :]
|
||||
|
||||
# dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
# hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
# low_val = -hi_val
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
# U = U.clamp(low_val, hi_val)
|
||||
# Vh = Vh.clamp(low_val, hi_val)
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
||||
@ -160,7 +164,7 @@ def merge(args):
|
||||
save_to_file(args.save_to, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
@ -178,5 +182,11 @@ if __name__ == '__main__':
|
||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
|
@ -112,7 +112,7 @@ def save_configuration(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -225,7 +225,7 @@ def open_configuration(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
@ -320,7 +320,7 @@ def train_model(
|
||||
sample_every_n_epochs,
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
additional_parameters,vae_batch_size,
|
||||
):
|
||||
if pretrained_model_name_or_path == '':
|
||||
msgbox('Source model information is missing')
|
||||
@ -511,6 +511,7 @@ def train_model(
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
noise_offset=noise_offset,
|
||||
additional_parameters=additional_parameters,
|
||||
vae_batch_size=vae_batch_size,
|
||||
)
|
||||
run_cmd += f' --token_string="{token_string}"'
|
||||
run_cmd += f' --init_word="{init_word}"'
|
||||
@ -770,6 +771,7 @@ def ti_tab(
|
||||
caption_dropout_rate,
|
||||
noise_offset,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
) = gradio_advanced_training()
|
||||
color_aug.change(
|
||||
color_aug_changed,
|
||||
@ -876,6 +878,7 @@ def ti_tab(
|
||||
sample_sampler,
|
||||
sample_prompts,
|
||||
additional_parameters,
|
||||
vae_batch_size,
|
||||
]
|
||||
|
||||
button_open_config.click(
|
||||
|
@ -13,12 +13,18 @@ def canny(args):
|
||||
print("done!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input", type=str, default=None, help="input path")
|
||||
parser.add_argument("--output", type=str, default=None, help="output path")
|
||||
parser.add_argument("--thres1", type=int, default=32, help="thres1")
|
||||
parser.add_argument("--thres2", type=int, default=224, help="thres2")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
canny(args)
|
||||
|
@ -61,7 +61,7 @@ def convert(args):
|
||||
print(f"model saved.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v1", action='store_true',
|
||||
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
|
||||
@ -84,6 +84,11 @@ if __name__ == '__main__':
|
||||
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
|
||||
parser.add_argument("model_to_save", type=str, default=None,
|
||||
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
@ -214,7 +214,7 @@ def process(args):
|
||||
buf.tofile(f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
|
||||
parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
|
||||
@ -234,6 +234,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--multiple_faces", action="store_true",
|
||||
help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
|
||||
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
process(args)
|
||||
|
@ -98,7 +98,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
||||
|
||||
|
||||
def main():
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
|
||||
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
|
||||
@ -113,6 +113,12 @@ def main():
|
||||
parser.add_argument('--copy_associated_files', action='store_true',
|
||||
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
|
||||
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
|
||||
|
12
train_db.py
12
train_db.py
@ -114,7 +114,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
@ -159,7 +159,7 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
if args.stop_text_encoder_training is None:
|
||||
@ -381,7 +381,7 @@ def train(args):
|
||||
print("model saved.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
@ -403,6 +403,12 @@ if __name__ == "__main__":
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
|
@ -139,7 +139,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
@ -196,7 +196,7 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
if is_main_process:
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
@ -644,7 +644,7 @@ def train(args):
|
||||
print("model saved.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
@ -687,6 +687,12 @@ if __name__ == "__main__":
|
||||
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
|
@ -228,7 +228,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae)
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
@ -257,7 +257,7 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# lr schedulerを用意する
|
||||
@ -526,7 +526,7 @@ def load_weights(file):
|
||||
return emb
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
@ -565,6 +565,12 @@ if __name__ == "__main__":
|
||||
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user