Update finetune_py to V4

bmaltais 2022-12-02 12:48:43 -05:00
parent 621dabcadf
commit 95a694a3fd
6 changed files with 422 additions and 116 deletions

View File

@@ -71,6 +71,7 @@ def clean_caption(caption):
  replaced = bef != caption
  return caption

def main(args):
  image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
  print(f"found {len(image_paths)} images.")
@@ -95,16 +96,16 @@ def main(args):
    return

  tags = metadata[image_key].get('tags')
- caption = metadata[image_key].get('caption')
  if tags is None:
    print(f"image does not have tags / メタデータにタグがありません: {image_path}")
-   return
+ else:
+   metadata[image_key]['tags'] = clean_tags(image_key, tags)

+ caption = metadata[image_key].get('caption')
  if caption is None:
    print(f"image does not have caption / メタデータにキャプションがありません: {image_path}")
-   return
+ else:
+   metadata[image_key]['caption'] = clean_caption(caption)

- metadata[image_key]['tags'] = clean_tags(image_key, tags)
- metadata[image_key]['caption'] = clean_caption(caption)

  # metadataを書き出して終わり
  print(f"writing metadata: {args.out_json}")

View File

@@ -1,5 +1,6 @@
# v2: select precision for saved checkpoint
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
+ # v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model

# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
# License:
@@ -44,12 +45,15 @@ import fine_tuning_utils

# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"     # ここからtokenizerだけ使う

# checkpointファイル名
LAST_CHECKPOINT_NAME = "last.ckpt"
LAST_STATE_NAME = "last-state"
+ LAST_DIFFUSERS_DIR_NAME = "last"
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
EPOCH_STATE_NAME = "epoch-{:06d}-state"
+ EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"

def collate_fn(examples):
@@ -63,7 +67,7 @@ class FineTuningDataset(torch.utils.data.Dataset):
    self.metadata = metadata
    self.train_data_dir = train_data_dir
    self.batch_size = batch_size
-   self.tokenizer = tokenizer
+   self.tokenizer: CLIPTokenizer = tokenizer
    self.max_token_length = max_token_length
    self.shuffle_caption = shuffle_caption
    self.debug = debug
@@ -159,17 +163,38 @@ class FineTuningDataset(torch.utils.data.Dataset):
    input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
                               max_length=self.tokenizer_max_length, return_tensors="pt").input_ids

- # 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
- # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
    if self.tokenizer_max_length > self.tokenizer.model_max_length:
      input_ids = input_ids.squeeze(0)
      iids_list = []
-     for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
-       iid = (input_ids[0].unsqueeze(0),
-              input_ids[i:i + self.tokenizer.model_max_length - 2],
-              input_ids[-1].unsqueeze(0))
-       iid = torch.cat(iid)
-       iids_list.append(iid)
+     if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+       # v1
+       # 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
+       # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
+       for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):  # (1, 152, 75)
+         ids_chunk = (input_ids[0].unsqueeze(0),
+                      input_ids[i:i + self.tokenizer.model_max_length - 2],
+                      input_ids[-1].unsqueeze(0))
+         ids_chunk = torch.cat(ids_chunk)
+         iids_list.append(ids_chunk)
+     else:
+       # v2
+       # 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
+       for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
+         ids_chunk = (input_ids[0].unsqueeze(0),       # BOS
+                      input_ids[i:i + self.tokenizer.model_max_length - 2],
+                      input_ids[-1].unsqueeze(0))      # PAD or EOS
+         ids_chunk = torch.cat(ids_chunk)
+
+         # 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
+         # 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
+         if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
+           ids_chunk[-1] = self.tokenizer.eos_token_id
+         # 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
+         if ids_chunk[1] == self.tokenizer.pad_token_id:
+           ids_chunk[1] = self.tokenizer.eos_token_id
+
+         iids_list.append(ids_chunk)

      input_ids = torch.stack(iids_list)      # 3,77
    input_ids_list.append(input_ids)
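
Editor's note: the hunk above splits an over-length caption into 77-token windows. The following sketch is not part of the commit; it reproduces the v1 branch with simulated token ids (49406/49407 are CLIP's BOS/EOS ids, and 227 corresponds to max_token_length=225 plus BOS and EOS).

import torch

# Illustrative only: build a fake 227-token sequence and cut it into three
# "<BOS> ... <EOS>" windows of 77 tokens each, as the v1 branch does.
bos, eos = 49406, 49407
model_max_length = 77          # CLIP's native window
tokenizer_max_length = 227     # 225 usable tokens + BOS + EOS

content = list(range(1, tokenizer_max_length - 1))     # 225 dummy token ids
input_ids = torch.tensor([bos] + content + [eos])      # shape (227,)

iids_list = []
for i in range(1, tokenizer_max_length - model_max_length + 2, model_max_length - 2):  # i = 1, 76, 151
    chunk = torch.cat((input_ids[0].unsqueeze(0),               # BOS
                       input_ids[i:i + model_max_length - 2],   # 75 content tokens
                       input_ids[-1].unsqueeze(0)))             # EOS
    iids_list.append(chunk)

print(torch.stack(iids_list).shape)   # torch.Size([3, 77])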
@@ -192,15 +217,17 @@ def save_hypernetwork(output_file, hypernetwork):

def train(args):
  fine_tuning = args.hypernetwork_module is None            # fine tuning or hypernetwork training

+ # その他のオプション設定を確認する
+ if args.v_parameterization and not args.v2:
+   print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
+ if args.v2 and args.clip_skip is not None:
+   print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

  # モデル形式のオプション設定を確認する
+ # v11からDiffUsersから直接落としてくるのもOK(ただし認証がいるやつは未対応)、またv11からDiffUsersも途中保存に対応した
  use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
- if not use_stable_diffusion_format:
-   assert os.path.exists(
-       args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}"
- assert not fine_tuning or (
-     args.save_every_n_epochs is None or use_stable_diffusion_format), "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります"

+ # 乱数系列を初期化する
  if args.seed is not None:
    set_seed(args.seed)
@@ -215,18 +242,22 @@ def train(args):
  # tokenizerを読み込む
  print("prepare tokenizer")
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
+ if args.v2:
+   tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
+ else:
+   tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)

  if args.max_token_length is not None:
-   print(f"update token length in tokenizer: {args.max_token_length}")
+   print(f"update token length: {args.max_token_length}")

  # datasetを用意する
  print("prepare dataset")
  train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size,
                                    tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset)

+ print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
+ print(f"Total images / 画像数: {train_dataset.images_count}")

  if args.debug_dataset:
-   print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
-   print(f"Total images / 画像数: {train_dataset.images_count}")
    train_dataset.show_buckets()
    i = 0
    for example in train_dataset:
@@ -251,14 +282,33 @@ def train(args):
  accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
                            mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)

+ # mixed precisionに対応した型を用意しておき適宜castする
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+   weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+   weight_dtype = torch.bfloat16
+
+ save_dtype = None
+ if args.save_precision == "fp16":
+   save_dtype = torch.float16
+ elif args.save_precision == "bf16":
+   save_dtype = torch.bfloat16
+ elif args.save_precision == "float":
+   save_dtype = torch.float32

  # モデルを読み込む
  if use_stable_diffusion_format:
    print("load StableDiffusion checkpoint")
-   text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path)
+   text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(
+       args.v2, args.pretrained_model_name_or_path)
  else:
    print("load Diffusers pretrained models")
-   text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-   unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+   pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
+   # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる
+   text_encoder = pipe.text_encoder
+   unet = pipe.unet
+   del pipe

  # モデルに xformers とか memory efficient attention を組み込む
  replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -279,21 +329,6 @@ def train(args):
    print("apply hypernetwork")
    hypernetwork.apply_to_diffusers(None, text_encoder, unet)

- # mixed precisionに対応した型を用意しておき適宜castする
- weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
-   weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
-   weight_dtype = torch.bfloat16
-
- save_dtype = None
- if args.save_precision == "fp16":
-   save_dtype = torch.float16
- elif args.save_precision == "bf16":
-   save_dtype = torch.bfloat16
- elif args.save_precision == "float":
-   save_dtype = torch.float32

  # 学習を準備する:モデルを適切な状態にする
  training_models = []
  if fine_tuning:
@@ -351,7 +386,7 @@ def train(args):
  # lr schedulerを用意する
  lr_scheduler = diffusers.optimization.get_scheduler(
-     "constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+     args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)

  # acceleratorがなんかよろしくやってくれるらしい
  if fine_tuning:
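
Editor's note: for context, this is roughly what the new --lr_scheduler / --lr_warmup_steps options feed into. The sketch is not from the commit; the optimizer, scheduler name, and step counts below are made-up values.

import torch
import diffusers

# Stand-in parameter and optimizer just to drive the scheduler.
params = [torch.nn.Parameter(torch.zeros(10))]
optimizer = torch.optim.AdamW(params, lr=1e-5)

lr_scheduler = diffusers.optimization.get_scheduler(
    "cosine_with_restarts",        # would come from args.lr_scheduler
    optimizer,
    num_warmup_steps=100,          # would come from args.lr_warmup_steps
    num_training_steps=1600,       # max_train_steps * gradient_accumulation_steps
)

for _ in range(5):
    optimizer.step()
    lr_scheduler.step()
print(lr_scheduler.get_last_lr())  # learning rate after 5 steps of warmup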
@@ -384,10 +419,14 @@ def train(args):
  print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
  print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
  global_step = 0

- noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+ # v4で更新:clip_sample=Falseに
+ # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる
+ # 既存の1.4/1.5/2.0はすべてschdulerのconfigはクラス名を除いて同じ
+ noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+                                 num_train_timesteps=1000, clip_sample=False)

  if accelerator.is_main_process:
    accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork")
@@ -400,7 +439,7 @@ def train(args):
    loss_total = 0
    for step, batch in enumerate(train_dataloader):
-     with accelerator.accumulate(training_models[0]):  # ここはこれでいいのか……?
+     with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
        latents = batch["latents"].to(accelerator.device)
        latents = latents * 0.18215
        b_size = latents.shape[0]
@@ -418,15 +457,29 @@ def train(args):
          encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
          encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

+       # bs*3, 77, 768 or 1024
        encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))

        if args.max_token_length is not None:
-         # <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
-         sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
-         for i in range(1, args.max_token_length, tokenizer.model_max_length):
-           sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
-         sts_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
-         encoder_hidden_states = torch.cat(sts_list, dim=1)
+         if args.v2:
+           # v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
+           states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]                              # <BOS>
+           for i in range(1, args.max_token_length, tokenizer.model_max_length):
+             chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]              # <BOS> の後から 最後の前まで
+             if i > 0:
+               for j in range(len(chunk)):
+                 if input_ids[j, 1] == tokenizer.eos_token:                                      # 空、つまり <BOS> <EOS> <PAD> ...のパターン
+                   chunk[j, 0] = chunk[j, 1]                                                     # 次の <PAD> の値をコピーする
+             states_list.append(chunk)                                                           # <BOS> の後から <EOS> の前まで
+           states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))                         # <EOS> か <PAD> のどちらか
+           encoder_hidden_states = torch.cat(states_list, dim=1)
+         else:
+           # v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
+           states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]                              # <BOS>
+           for i in range(1, args.max_token_length, tokenizer.model_max_length):
+             states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])  # <BOS> の後から <EOS> の前まで
+           states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))                         # <EOS>
+           encoder_hidden_states = torch.cat(states_list, dim=1)

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents, device=latents.device)
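
Editor's note: a sketch of the v1 re-assembly above, not taken from the commit. Random tensors stand in for the text encoder output; the shapes assume max_token_length=225, CLIP's 77-token window, and hidden size 768.

import torch

# (b, 3*77, hidden) per-chunk encoder outputs -> (b, 227, hidden) single sequence.
b_size, hidden = 2, 768
max_token_length, model_max_length = 225, 77
encoder_hidden_states = torch.randn(b_size, 3 * model_max_length, hidden)       # (2, 231, 768)

states_list = [encoder_hidden_states[:, 0].unsqueeze(1)]                        # first <BOS>
for i in range(1, max_token_length, model_max_length):                          # i = 1, 78, 155
    states_list.append(encoder_hidden_states[:, i:i + model_max_length - 2])    # drop each chunk's BOS/EOS
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))                   # last <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)

print(encoder_hidden_states.shape)   # torch.Size([2, 227, 768])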
@@ -442,7 +495,41 @@ def train(args):
        # Predict the noise residual
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-       loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+       if args.v_parameterization:
+         # v-parameterization training
+         # 11/29現在v predictionのコードがDiffusersにcommitされたがリリースされていないので独自コードを使う
+         # 実装の中身は同じ模様
+
+         # こうしたい:
+         # target = noise_scheduler.get_v(latents, noise, timesteps)
+
+         # StabilityAiのddpm.pyのコード:
+         # elif self.parameterization == "v":
+         #     target = self.get_v(x_start, noise, t)
+         # ...
+         # def get_v(self, x, noise, t):
+         #     return (
+         #         extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+         #         extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+         #     )
+
+         # scheduling_ddim.pyのコード:
+         # elif self.config.prediction_type == "v_prediction":
+         #     pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+         #     # predict V
+         #     model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+         # これでいいかな?:
+         alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
+         beta_prod_t = 1 - alpha_prod_t
+         alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1))   # broadcastされないらしいのでreshape
+         beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
+         target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
+       else:
+         target = noise
+
+       loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

        accelerator.backward(loss)
        if accelerator.sync_gradients:
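
Editor's note: the v_parameterization branch computes the velocity target target = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents. The sketch below is not part of the commit; it recomputes that target with a hand-built scaled-linear alpha-bar table standing in for DDPMScheduler.alphas_cumprod, and all shapes and values are illustrative.

import torch

num_train_timesteps = 1000
# "scaled_linear" beta schedule: linspace over sqrt(beta), then squared.
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, num_train_timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

latents = torch.randn(4, 4, 64, 64)            # stand-in x0 (already VAE-scaled in the script)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, num_train_timesteps, (4,))

alpha_prod_t = alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)   # reshape so it broadcasts over C,H,W
beta_prod_t = 1 - alpha_prod_t
target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents

print(target.shape)   # torch.Size([4, 4, 64, 64])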
@@ -460,7 +547,7 @@ def train(args):
          progress_bar.update(1)
          global_step += 1

-         current_loss = loss.detach().item() * b_size
+         current_loss = loss.detach().item()        # 平均なのでbatch sizeは関係ないはず
          if args.logging_dir is not None:
            logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
            accelerator.log(logs, step=global_step)
@@ -481,14 +568,20 @@ def train(args):
      if args.save_every_n_epochs is not None:
        if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
-         print("saving check point.")
+         print("saving checkpoint.")
          os.makedirs(args.output_dir, exist_ok=True)
          ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))

          if fine_tuning:
-           fine_tuning_utils.save_stable_diffusion_checkpoint(
-               ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
-               args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
+           if use_stable_diffusion_format:
+             fine_tuning_utils.save_stable_diffusion_checkpoint(
+                 args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
+                 args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
+           else:
+             out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
+             os.makedirs(out_dir, exist_ok=True)
+             fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
+                                                         accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype)
          else:
            save_hypernetwork(ckpt_file, accelerator.unwrap_model(hypernetwork))
@@ -519,16 +612,14 @@ def train(args):
      ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
      print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
      fine_tuning_utils.save_stable_diffusion_checkpoint(
-         ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
+         args.v2, ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
    else:
      # Create the pipeline using using the trained modules and save it.
      print(f"save trained model as Diffusers to {args.output_dir}")
-     pipeline = StableDiffusionPipeline.from_pretrained(
-         args.pretrained_model_name_or_path,
-         unet=unet,
-         text_encoder=text_encoder,
-     )
-     pipeline.save_pretrained(args.output_dir)
+     out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
+     os.makedirs(out_dir, exist_ok=True)
+     fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
+                                                 args.pretrained_model_name_or_path, save_dtype)
  else:
    ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
    print(f"save trained model to {ckpt_file}")
@@ -817,6 +908,10 @@ def replace_unet_cross_attn_to_xformers():
if __name__ == '__main__':
  # torch.cuda.set_per_process_memory_fraction(0.48)
  parser = argparse.ArgumentParser()
+ parser.add_argument("--v2", action='store_true',
+                     help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
+ parser.add_argument("--v_parameterization", action='store_true',
+                     help='enable v-parameterization training / v-parameterization学習を有効にする')
  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
                      help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
  parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル")
@@ -832,7 +927,7 @@ if __name__ == '__main__':
  parser.add_argument("--hypernetwork_weights", type=str, default=None,
                      help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)')
  parser.add_argument("--save_every_n_epochs", type=int, default=None,
-                     help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存する(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
+                     help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
  parser.add_argument("--save_state", action="store_true",
                      help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
  parser.add_argument("--resume", type=str, default=None,
@@ -857,13 +952,17 @@ if __name__ == '__main__':
  parser.add_argument("--mixed_precision", type=str, default="no",
                      choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
  parser.add_argument("--save_precision", type=str, default=None,
-                     choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
+                     choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
  parser.add_argument("--clip_skip", type=int, default=None,
                      help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
  parser.add_argument("--debug_dataset", action="store_true",
                      help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
  parser.add_argument("--logging_dir", type=str, default=None,
                      help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
+                     help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
+                     help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")

  args = parser.parse_args()
  train(args)

View File

@@ -2,11 +2,13 @@ import math
import torch
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+
+ # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"

- # StableDiffusionのモデルパラメータ
+ # region checkpoint変換、読み込み、書き込み ###############################
+
+ # DiffUsers版StableDiffusionのモデルパラメータ
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120
@@ -29,13 +31,15 @@ VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

+ # V2
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024

- # region conversion
- # checkpoint変換など ###############################

# region StableDiffusion->Diffusersの変換コード
# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)

def shave_segments(path, n_shave_prefix_segments=1):
  """
  Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -199,7 +203,16 @@ def conv_attn_to_linear(checkpoint):
      checkpoint[key] = checkpoint[key][:, :, 0]

- def convert_ldm_unet_checkpoint(checkpoint, config):
+ def linear_transformer_to_conv(checkpoint):
+   keys = list(checkpoint.keys())
+   tf_keys = ["proj_in.weight", "proj_out.weight"]
+   for key in keys:
+     if ".".join(key.split(".")[-2:]) in tf_keys:
+       if checkpoint[key].ndim == 2:
+         checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
+
+
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
  """
  Takes a state dict and a config, and returns a converted checkpoint.
  """
@@ -349,6 +362,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config):
        new_checkpoint[new_path] = unet_state_dict[old_path]

+ # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
+ if v2:
+   linear_transformer_to_conv(new_checkpoint)

  return new_checkpoint
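
Editor's note on why linear_transformer_to_conv is needed: SD 2.x checkpoints store the transformer block's proj_in/proj_out as Linear weights (2-D), while the Diffusers UNet built here expects 1x1 Conv2d weights (4-D); unsqueezing twice adds the trailing spatial dims. The snippet below is only an illustration (the 1280-dim weight is made up), not commit code.

import torch

linear_weight = torch.randn(1280, 1280)              # [out_features, in_features] from the SD2 checkpoint
conv_weight = linear_weight.unsqueeze(2).unsqueeze(2)  # what linear_transformer_to_conv produces
print(conv_weight.shape)                             # torch.Size([1280, 1280, 1, 1])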
@@ -459,7 +476,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
  return new_checkpoint

- def create_unet_diffusers_config():
+ def create_unet_diffusers_config(v2):
  """
  Creates a config for the diffusers based on the config of the LDM model.
  """
@@ -489,8 +506,8 @@ def create_unet_diffusers_config():
      up_block_types=tuple(up_block_types),
      block_out_channels=tuple(block_out_channels),
      layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
-     cross_attention_dim=UNET_PARAMS_CONTEXT_DIM,
-     attention_head_dim=UNET_PARAMS_NUM_HEADS,
+     cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
+     attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
  )

  return config
@@ -519,20 +536,82 @@ def create_vae_diffusers_config():
  return config

- def convert_ldm_clip_checkpoint(checkpoint):
-   text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
  keys = list(checkpoint.keys())
  text_model_dict = {}
  for key in keys:
    if key.startswith("cond_stage_model.transformer"):
      text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+ return text_model_dict

- text_model.load_state_dict(text_model_dict)
- return text_model
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
+   # 嫌になるくらい違うぞ!
+   def convert_key(key):
+     if not key.startswith("cond_stage_model"):
+       return None
+
+     # common conversion
+     key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+     key = key.replace("cond_stage_model.model.", "text_model.")
+
+     if "resblocks" in key:
+       # resblocks conversion
+       key = key.replace(".resblocks.", ".layers.")
+       if ".ln_" in key:
+         key = key.replace(".ln_", ".layer_norm")
+       elif ".mlp." in key:
+         key = key.replace(".c_fc.", ".fc1.")
+         key = key.replace(".c_proj.", ".fc2.")
+       elif '.attn.out_proj' in key:
+         key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+       elif '.attn.in_proj' in key:
+         key = None                          # 特殊なので後で処理する
+       else:
+         raise ValueError(f"unexpected key in SD: {key}")
+     elif '.positional_embedding' in key:
+       key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+     elif '.text_projection' in key:
+       key = None                            # 使われない???
+     elif '.logit_scale' in key:
+       key = None                            # 使われない???
+     elif '.token_embedding' in key:
+       key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+     elif '.ln_final' in key:
+       key = key.replace(".ln_final", ".final_layer_norm")
+     return key
+
+   keys = list(checkpoint.keys())
+   new_sd = {}
+   for key in keys:
+     # remove resblocks 23
+     if '.resblocks.23.' in key:
+       continue
+     new_key = convert_key(key)
+     if new_key is None:
+       continue
+     new_sd[new_key] = checkpoint[key]
+
+   # attnの変換
+   for key in keys:
+     if '.resblocks.23.' in key:
+       continue
+     if '.resblocks' in key and '.attn.in_proj_' in key:
+       # 三つに分割
+       values = torch.chunk(checkpoint[key], 3)
+
+       key_suffix = ".weight" if "weight" in key else ".bias"
+       key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+       key_pfx = key_pfx.replace("_weight", "")
+       key_pfx = key_pfx.replace("_bias", "")
+       key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+       new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+       new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+       new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+   # position_idsの追加
+   new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
+
+   return new_sd

# endregion
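
Editor's note: one detail of the v2 conversion above is that the open-CLIP text model stores attention q/k/v as a single fused in_proj tensor, which convert_ldm_clip_checkpoint_v2 splits with torch.chunk. A toy example, not from the commit (hidden size 1024 as in SD2's text encoder, random weights):

import torch

hidden = 1024
in_proj_weight = torch.randn(3 * hidden, hidden)   # fused qkv weight from the checkpoint
q_w, k_w, v_w = torch.chunk(in_proj_weight, 3)     # three [1024, 1024] projection weights
print(q_w.shape, k_w.shape, v_w.shape)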
@@ -540,7 +619,16 @@ def convert_ldm_clip_checkpoint(checkpoint):

# region Diffusers->StableDiffusion の変換コード
# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)

- def convert_unet_state_dict(unet_state_dict):
+ def conv_transformer_to_linear(checkpoint):
+   keys = list(checkpoint.keys())
+   tf_keys = ["proj_in.weight", "proj_out.weight"]
+   for key in keys:
+     if ".".join(key.split(".")[-2:]) in tf_keys:
+       if checkpoint[key].ndim > 2:
+         checkpoint[key] = checkpoint[key][:, :, 0, 0]
+
+
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
  unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
@@ -629,12 +717,16 @@ def convert_unet_state_dict(unet_state_dict):
      v = v.replace(hf_part, sd_part)
      mapping[k] = v
  new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}

+ if v2:
+   conv_transformer_to_linear(new_state_dict)

  return new_state_dict

# endregion

- def load_checkpoint_with_conversion(ckpt_path):
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
  # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
  TEXT_ENCODER_KEY_REPLACEMENTS = [
      ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
@@ -659,52 +751,148 @@ def load_checkpoint_with_conversion(ckpt_path):
  return checkpoint

- def load_models_from_stable_diffusion_checkpoint(ckpt_path):
-   checkpoint = load_checkpoint_with_conversion(ckpt_path)
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
+   checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
  state_dict = checkpoint["state_dict"]
+ if dtype is not None:
+   for k, v in state_dict.items():
+     if type(v) is torch.Tensor:
+       state_dict[k] = v.to(dtype)

  # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config()
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
+ unet_config = create_unet_diffusers_config(v2)
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)

  unet = UNet2DConditionModel(**unet_config)
- unet.load_state_dict(converted_unet_checkpoint)
+ info = unet.load_state_dict(converted_unet_checkpoint)
+ print("loading u-net:", info)

  # Convert the VAE model.
  vae_config = create_vae_diffusers_config()
  converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

  vae = AutoencoderKL(**vae_config)
- vae.load_state_dict(converted_vae_checkpoint)
+ info = vae.load_state_dict(converted_vae_checkpoint)
+ print("loadint vae:", info)

  # convert text_model
- text_model = convert_ldm_clip_checkpoint(state_dict)
+ if v2:
+   converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+   cfg = CLIPTextConfig(
+       vocab_size=49408,
+       hidden_size=1024,
+       intermediate_size=4096,
+       num_hidden_layers=23,
+       num_attention_heads=16,
+       max_position_embeddings=77,
+       hidden_act="gelu",
+       layer_norm_eps=1e-05,
+       dropout=0.0,
+       attention_dropout=0.0,
+       initializer_range=0.02,
+       initializer_factor=1.0,
+       pad_token_id=1,
+       bos_token_id=0,
+       eos_token_id=2,
+       model_type="clip_text_model",
+       projection_dim=512,
+       torch_dtype="float32",
+       transformers_version="4.25.0.dev0",
+   )
+   text_model = CLIPTextModel._from_config(cfg)
+   info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ else:
+   converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+   text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+   info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ print("loading text encoder:", info)

  return text_model, vae, unet

- def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
+   def convert_key(key):
+     # position_idsの除去
+     if ".position_ids" in key:
+       return None
+
+     # common
+     key = key.replace("text_model.encoder.", "transformer.")
+     key = key.replace("text_model.", "")
+     if "layers" in key:
+       # resblocks conversion
+       key = key.replace(".layers.", ".resblocks.")
+       if ".layer_norm" in key:
+         key = key.replace(".layer_norm", ".ln_")
+       elif ".mlp." in key:
+         key = key.replace(".fc1.", ".c_fc.")
+         key = key.replace(".fc2.", ".c_proj.")
+       elif '.self_attn.out_proj' in key:
+         key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+       elif '.self_attn.' in key:
+         key = None                          # 特殊なので後で処理する
+       else:
+         raise ValueError(f"unexpected key in DiffUsers model: {key}")
+     elif '.position_embedding' in key:
+       key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+     elif '.token_embedding' in key:
+       key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+     elif 'final_layer_norm' in key:
+       key = key.replace("final_layer_norm", "ln_final")
+     return key
+
+   keys = list(checkpoint.keys())
+   new_sd = {}
+   for key in keys:
+     new_key = convert_key(key)
+     if new_key is None:
+       continue
+     new_sd[new_key] = checkpoint[key]
+
+   # attnの変換
+   for key in keys:
+     if 'layers' in key and 'q_proj' in key:
+       # 三つを結合
+       key_q = key
+       key_k = key.replace("q_proj", "k_proj")
+       key_v = key.replace("q_proj", "v_proj")
+
+       value_q = checkpoint[key_q]
+       value_k = checkpoint[key_k]
+       value_v = checkpoint[key_v]
+       value = torch.cat([value_q, value_k, value_v])
+
+       new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+       new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+       new_sd[new_key] = value
+
+   return new_sd
+
+
+ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
  # VAEがメモリ上にないので、もう一度VAEを含めて読み込む
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
+ checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
  state_dict = checkpoint["state_dict"]

+ def assign_new_sd(prefix, sd):
+   for k, v in sd.items():
+     key = prefix + k
+     assert key in state_dict, f"Illegal key in save SD: {key}"
+     if save_dtype is not None:
+       v = v.detach().clone().to("cpu").to(save_dtype)
+     state_dict[key] = v

  # Convert the UNet model
- unet_state_dict = convert_unet_state_dict(unet.state_dict())
- for k, v in unet_state_dict.items():
-   key = "model.diffusion_model." + k
-   assert key in state_dict, f"Illegal key in save SD: {key}"
-   if save_dtype is not None:
-     v = v.detach().clone().to("cpu").to(save_dtype)
-   state_dict[key] = v
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
+ assign_new_sd("model.diffusion_model.", unet_state_dict)

  # Convert the text encoder model
- text_enc_dict = text_encoder.state_dict()             # 変換不要
- for k, v in text_enc_dict.items():
-   key = "cond_stage_model.transformer." + k
-   assert key in state_dict, f"Illegal key in save SD: {key}"
-   if save_dtype is not None:
-     v = v.detach().clone().to("cpu").to(save_dtype)
-   state_dict[key] = v
+ if v2:
+   text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict())
+   assign_new_sd("cond_stage_model.model.", text_enc_dict)
+ else:
+   text_enc_dict = text_encoder.state_dict()
+   assign_new_sd("cond_stage_model.transformer.", text_enc_dict)

  # Put together new checkpoint
  new_ckpt = {'state_dict': state_dict}
@@ -718,6 +906,21 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
  new_ckpt['global_step'] = steps

  torch.save(new_ckpt, output_file)

+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype):
+   pipeline = StableDiffusionPipeline(
+       unet=unet,
+       text_encoder=text_encoder,
+       vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"),
+       scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
+       tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
+       safety_checker=None,
+       feature_extractor=None,
+       requires_safety_checker=None,
+   )
+   pipeline.save_pretrained(output_dir)

# endregion
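
Editor's note: a hedged usage sketch of the updated fine_tuning_utils entry points whose new v2-aware signatures appear in the diff above. It is not part of the commit; the checkpoint paths, epoch/step counts, and dtype are placeholders.

import torch
import fine_tuning_utils

# Load an SD2.x checkpoint into Diffusers-style modules (v2=True selects the new conversion path).
text_encoder, vae, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(
    True, "v2-model.ckpt")

# ... training would happen here ...

# Write the result back out as a single .ckpt, casting weights to fp16 on save.
fine_tuning_utils.save_stable_diffusion_checkpoint(
    True, "trained.ckpt", text_encoder, unet, "v2-model.ckpt",
    epochs=1, steps=100, save_dtype=torch.float16)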

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm

def main(args):
- image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  if args.in_json is not None:

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm

def main(args):
- image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  if args.in_json is not None:

View File

@@ -36,7 +36,7 @@ def get_latents(vae, images, weight_dtype):

def main(args):
- image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
  print(f"found {len(image_paths)} images.")

  if os.path.exists(args.in_json):
@@ -49,13 +49,11 @@ def main(args):
  # モデル形式のオプション設定を確認する
  use_stable_diffusion_format = os.path.isfile(args.model_name_or_path)
- if not use_stable_diffusion_format:
-   assert os.path.exists(args.model_name_or_path), f"no model / モデルがありません : {args.model_name_or_path}"

  # モデルを読み込む
  if use_stable_diffusion_format:
    print("load StableDiffusion checkpoint")
-   _, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.model_name_or_path)
+   _, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_name_or_path)
  else:
    print("load Diffusers pretrained models")
    vae = AutoencoderKL.from_pretrained(args.model_name_or_path, subfolder="vae")
@@ -73,7 +71,8 @@ def main(args):
  max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
  assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

- bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(max_reso)
+ bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(
+     max_reso, args.min_bucket_reso, args.max_bucket_reso)

  # 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
  bucket_aspect_ratios = np.array(bucket_aspect_ratios)
@@ -162,9 +161,13 @@ if __name__ == '__main__':
  parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
  parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
  parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
+ parser.add_argument("--v2", action='store_true',
+                     help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
  parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
  parser.add_argument("--max_resolution", type=str, default="512,512",
                      help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
  parser.add_argument("--mixed_precision", type=str, default="no",
                      choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")