Update diffusers fine tuning code
This commit is contained in:
parent
dd241f2142
commit
91eff4da17
@ -1,3 +1,10 @@
|
|||||||
# Diffusers Fine Tuning
|
# Diffusers Fine Tuning
|
||||||
|
|
||||||
This subfolder provide all the required toold to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
|
This subfolder provide all the required tools to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29
|
||||||
|
|
||||||
|
## Releases
|
||||||
|
|
||||||
|
11/23 (v3):
|
||||||
|
- Added WD14Tagger tagging script.
|
||||||
|
- A log output function has been added to the fine_tune.py. Also, fixed the double shuffling of data.
|
||||||
|
- Fixed misspelling of options for each script (caption_extention→caption_extension will work for the time being, even if it remains outdated).
|
||||||
|
@ -13,11 +13,13 @@ def clean_tags(image_key, tags):
|
|||||||
# replace '_' to ' '
|
# replace '_' to ' '
|
||||||
tags = tags.replace('_', ' ')
|
tags = tags.replace('_', ' ')
|
||||||
|
|
||||||
# remove rating
|
# remove rating: deepdanbooruのみ
|
||||||
tokens = tags.split(", rating")
|
tokens = tags.split(", rating")
|
||||||
if len(tokens) == 1:
|
if len(tokens) == 1:
|
||||||
print("no rating:")
|
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
||||||
print(f"{image_key} {tags}")
|
# print("no rating:")
|
||||||
|
# print(f"{image_key} {tags}")
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
if len(tokens) > 2:
|
if len(tokens) > 2:
|
||||||
print("multiple ratings:")
|
print("multiple ratings:")
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# v2: select precision for saved checkpoint
|
# v2: select precision for saved checkpoint
|
||||||
|
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
|
||||||
|
|
||||||
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
||||||
# License:
|
# License:
|
||||||
@ -22,12 +22,12 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import itertools
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import json
|
import json
|
||||||
import importlib
|
import importlib
|
||||||
|
import time
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@ -159,7 +159,7 @@ class FineTuningDataset(torch.utils.data.Dataset):
|
|||||||
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
||||||
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
||||||
|
|
||||||
# 77以上の時は "<CLS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<CLS>...<EOS>"の三連に変換する
|
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||||
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
||||||
input_ids = input_ids.squeeze(0)
|
input_ids = input_ids.squeeze(0)
|
||||||
@ -242,7 +242,14 @@ def train(args):
|
|||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
print("prepare accelerator")
|
||||||
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)
|
if args.logging_dir is None:
|
||||||
|
log_with = None
|
||||||
|
logging_dir = None
|
||||||
|
else:
|
||||||
|
log_with = "tensorboard"
|
||||||
|
logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
@ -304,7 +311,7 @@ def train(args):
|
|||||||
text_encoder.requires_grad_(False) # text encoderは学習しない
|
text_encoder.requires_grad_(False) # text encoderは学習しない
|
||||||
text_encoder.eval()
|
text_encoder.eval()
|
||||||
else:
|
else:
|
||||||
unet.to(accelerator.device, dtype=weight_dtype)
|
unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない
|
||||||
unet.requires_grad_(False)
|
unet.requires_grad_(False)
|
||||||
unet.eval()
|
unet.eval()
|
||||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||||
@ -314,7 +321,10 @@ def train(args):
|
|||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.requires_grad_(True)
|
m.requires_grad_(True)
|
||||||
params_to_optimize = itertools.chain(*[m.parameters() for m in training_models])
|
params = []
|
||||||
|
for m in training_models:
|
||||||
|
params.extend(m.parameters())
|
||||||
|
params_to_optimize = params
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
print("prepare optimizer, data loader etc.")
|
||||||
@ -337,11 +347,11 @@ def train(args):
|
|||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
args.lr_scheduler, optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_warmup_steps=args.lr_warmup_steps)
|
"constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if fine_tuning:
|
if fine_tuning:
|
||||||
@ -390,7 +400,7 @@ def train(args):
|
|||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
with accelerator.accumulate(training_models[0]):
|
with accelerator.accumulate(training_models[0]): # ここはこれでいいのか……?
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
@ -411,7 +421,7 @@ def train(args):
|
|||||||
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
||||||
|
|
||||||
if args.max_token_length is not None:
|
if args.max_token_length is not None:
|
||||||
# <CLS>...<EOS> の三連を <CLS>...<EOS> へ戻す
|
# <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||||
sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
|
sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
|
||||||
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
||||||
sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
|
sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
|
||||||
@ -436,7 +446,9 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
params_to_clip = itertools.chain(*[m.parameters() for m in training_models])
|
params_to_clip = []
|
||||||
|
for m in training_models:
|
||||||
|
params_to_clip.extend(m.parameters())
|
||||||
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@ -449,15 +461,22 @@ def train(args):
|
|||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
current_loss = loss.detach().item() * b_size
|
current_loss = loss.detach().item() * b_size
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
loss_total += current_loss
|
loss_total += current_loss
|
||||||
avr_loss = loss_total / (step+1)
|
avr_loss = loss_total / (step+1)
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
progress_bar.set_postfix(**logs)
|
progress_bar.set_postfix(**logs)
|
||||||
# accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
||||||
|
accelerator.log(logs, step=epoch+1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
@ -843,7 +862,8 @@ if __name__ == '__main__':
|
|||||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||||
parser.add_argument("--debug_dataset", action="store_true",
|
parser.add_argument("--debug_dataset", action="store_true",
|
||||||
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
||||||
parser.add_argument("--lr_scheduler", type=str, default="constant", help="scheduler to use for learning rate: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup")
|
parser.add_argument("--logging_dir", type=str, default=None,
|
||||||
parser.add_argument("--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler.")
|
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
@ -48,7 +48,7 @@ def main(args):
|
|||||||
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
|
||||||
|
|
||||||
for (image_path, _), caption in zip(path_imgs, captions):
|
for (image_path, _), caption in zip(path_imgs, captions):
|
||||||
with open(os.path.splitext(image_path)[0] + args.caption_extention, "wt", encoding='utf-8') as f:
|
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||||
f.write(caption + "\n")
|
f.write(caption + "\n")
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_path, caption)
|
print(image_path, caption)
|
||||||
@ -76,7 +76,9 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("caption_weights", type=str,
|
parser.add_argument("caption_weights", type=str,
|
||||||
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
|
||||||
parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption file / 出力されるキャプションファイルの拡張子")
|
parser.add_argument("--caption_extention", type=str, default=None,
|
||||||
|
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||||
|
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||||
parser.add_argument("--beam_search", action="store_true",
|
parser.add_argument("--beam_search", action="store_true",
|
||||||
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||||
@ -87,4 +89,9 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# スペルミスしていたオプションを復元する
|
||||||
|
if args.caption_extention is not None:
|
||||||
|
args.caption_extension = args.caption_extention
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -24,7 +24,7 @@ def main(args):
|
|||||||
|
|
||||||
print("merge caption texts to metadata json.")
|
print("merge caption texts to metadata json.")
|
||||||
for image_path in tqdm(image_paths):
|
for image_path in tqdm(image_paths):
|
||||||
caption_path = os.path.splitext(image_path)[0] + args.caption_extention
|
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
|
||||||
with open(caption_path, "rt", encoding='utf-8') as f:
|
with open(caption_path, "rt", encoding='utf-8') as f:
|
||||||
caption = f.readlines()[0].strip()
|
caption = f.readlines()[0].strip()
|
||||||
|
|
||||||
@ -54,8 +54,15 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||||
parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||||
parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption file / 読み込むキャプションファイルの拡張子")
|
parser.add_argument("--caption_extention", type=str, default=None,
|
||||||
|
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||||
|
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# スペルミスしていたオプションを復元する
|
||||||
|
if args.caption_extention is not None:
|
||||||
|
args.caption_extension = args.caption_extention
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
107
diffusers_fine_tuning/tag_images_by_wd14_tagger.py
Normal file
107
diffusers_fine_tuning/tag_images_by_wd14_tagger.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
# このスクリプトのライセンスは、Apache License 2.0とします
|
||||||
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
from Utils import dbimutils
|
||||||
|
|
||||||
|
|
||||||
|
# from wd14 tagger
|
||||||
|
IMAGE_SIZE = 448
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||||
|
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
|
print("loading model and labels")
|
||||||
|
model = load_model(args.model)
|
||||||
|
|
||||||
|
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||||
|
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||||
|
with open(args.tag_csv, "r", encoding="utf-8") as f:
|
||||||
|
reader = csv.reader(f)
|
||||||
|
l = [row for row in reader]
|
||||||
|
header = l[0] # tag_id,name,category,count
|
||||||
|
rows = l[1:]
|
||||||
|
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
|
||||||
|
|
||||||
|
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
|
||||||
|
|
||||||
|
# 推論する
|
||||||
|
def run_batch(path_imgs):
|
||||||
|
imgs = np.array([im for _, im in path_imgs])
|
||||||
|
|
||||||
|
probs = model(imgs, training=False)
|
||||||
|
probs = probs.numpy()
|
||||||
|
|
||||||
|
for (image_path, _), prob in zip(path_imgs, probs):
|
||||||
|
# 最初の4つはratingなので無視する
|
||||||
|
# # First 4 labels are actually ratings: pick one with argmax
|
||||||
|
# ratings_names = label_names[:4]
|
||||||
|
# rating_index = ratings_names["probs"].argmax()
|
||||||
|
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
|
||||||
|
|
||||||
|
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
|
||||||
|
# Everything else is tags: pick any where prediction confidence > threshold
|
||||||
|
tag_text = ""
|
||||||
|
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
|
||||||
|
if p >= args.thresh:
|
||||||
|
tag_text += ", " + tags[i]
|
||||||
|
|
||||||
|
if len(tag_text) > 0:
|
||||||
|
tag_text = tag_text[2:] # 最初の ", " を消す
|
||||||
|
|
||||||
|
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
|
||||||
|
f.write(tag_text + '\n')
|
||||||
|
if args.debug:
|
||||||
|
print(image_path, tag_text)
|
||||||
|
|
||||||
|
b_imgs = []
|
||||||
|
for image_path in tqdm(image_paths):
|
||||||
|
img = dbimutils.smart_imread(image_path)
|
||||||
|
img = dbimutils.smart_24bit(img)
|
||||||
|
img = dbimutils.make_square(img, IMAGE_SIZE)
|
||||||
|
img = dbimutils.smart_resize(img, IMAGE_SIZE)
|
||||||
|
img = img.astype(np.float32)
|
||||||
|
b_imgs.append((image_path, img))
|
||||||
|
|
||||||
|
if len(b_imgs) >= args.batch_size:
|
||||||
|
run_batch(b_imgs)
|
||||||
|
b_imgs.clear()
|
||||||
|
if len(b_imgs) > 0:
|
||||||
|
run_batch(b_imgs)
|
||||||
|
|
||||||
|
print("done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
|
parser.add_argument("--model", type=str, default="networks/ViTB16_11_03_2022_07h05m53s",
|
||||||
|
help="model path to load / 読み込むモデルファイル")
|
||||||
|
parser.add_argument("--tag_csv", type=str, default="2022_0000_0899_6549/selected_tags.csv",
|
||||||
|
help="csv file for tags / タグ一覧のCSVファイル")
|
||||||
|
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||||
|
parser.add_argument("--caption_extention", type=str, default=None,
|
||||||
|
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
|
||||||
|
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
||||||
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# スペルミスしていたオプションを復元する
|
||||||
|
if args.caption_extention is not None:
|
||||||
|
args.caption_extension = args.caption_extention
|
||||||
|
|
||||||
|
main(args)
|
Loading…
Reference in New Issue
Block a user