Update finetune_py to V4
This commit is contained in:
parent
621dabcadf
commit
95a694a3fd
@ -71,6 +71,7 @@ def clean_caption(caption):
|
|||||||
replaced = bef != caption
|
replaced = bef != caption
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
@ -95,16 +96,16 @@ def main(args):
|
|||||||
return
|
return
|
||||||
|
|
||||||
tags = metadata[image_key].get('tags')
|
tags = metadata[image_key].get('tags')
|
||||||
caption = metadata[image_key].get('caption')
|
|
||||||
if tags is None:
|
if tags is None:
|
||||||
print(f"image does not have tags / メタデータにタグがありません: {image_path}")
|
print(f"image does not have tags / メタデータにタグがありません: {image_path}")
|
||||||
return
|
else:
|
||||||
|
metadata[image_key]['tags'] = clean_tags(image_key, tags)
|
||||||
|
|
||||||
|
caption = metadata[image_key].get('caption')
|
||||||
if caption is None:
|
if caption is None:
|
||||||
print(f"image does not have caption / メタデータにキャプションがありません: {image_path}")
|
print(f"image does not have caption / メタデータにキャプションがありません: {image_path}")
|
||||||
return
|
else:
|
||||||
|
metadata[image_key]['caption'] = clean_caption(caption)
|
||||||
metadata[image_key]['tags'] = clean_tags(image_key, tags)
|
|
||||||
metadata[image_key]['caption'] = clean_caption(caption)
|
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
print(f"writing metadata: {args.out_json}")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# v2: select precision for saved checkpoint
|
# v2: select precision for saved checkpoint
|
||||||
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
|
# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
|
||||||
|
# v4: support SD2.0, add lr scheduler options, support save_every_n_epochs and save_state for Diffusers models
|
||||||
|
|
||||||
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
||||||
# License:
|
# License:
|
||||||
@ -44,12 +45,15 @@ import fine_tuning_utils
|
|||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
|
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||||
|
|
||||||
# checkpointファイル名
|
# checkpointファイル名
|
||||||
LAST_CHECKPOINT_NAME = "last.ckpt"
|
LAST_CHECKPOINT_NAME = "last.ckpt"
|
||||||
LAST_STATE_NAME = "last-state"
|
LAST_STATE_NAME = "last-state"
|
||||||
|
LAST_DIFFUSERS_DIR_NAME = "last"
|
||||||
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
|
EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
|
||||||
EPOCH_STATE_NAME = "epoch-{:06d}-state"
|
EPOCH_STATE_NAME = "epoch-{:06d}-state"
|
||||||
|
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
@ -63,7 +67,7 @@ class FineTuningDataset(torch.utils.data.Dataset):
|
|||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.train_data_dir = train_data_dir
|
self.train_data_dir = train_data_dir
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer: CLIPTokenizer = tokenizer
|
||||||
self.max_token_length = max_token_length
|
self.max_token_length = max_token_length
|
||||||
self.shuffle_caption = shuffle_caption
|
self.shuffle_caption = shuffle_caption
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
@ -159,17 +163,38 @@ class FineTuningDataset(torch.utils.data.Dataset):
|
|||||||
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
||||||
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
||||||
|
|
||||||
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
|
||||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
|
||||||
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
||||||
input_ids = input_ids.squeeze(0)
|
input_ids = input_ids.squeeze(0)
|
||||||
iids_list = []
|
iids_list = []
|
||||||
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
||||||
iid = (input_ids[0].unsqueeze(0),
|
# v1
|
||||||
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||||
input_ids[-1].unsqueeze(0))
|
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||||
iid = torch.cat(iid)
|
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
|
||||||
iids_list.append(iid)
|
ids_chunk = (input_ids[0].unsqueeze(0),
|
||||||
|
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
||||||
|
input_ids[-1].unsqueeze(0))
|
||||||
|
ids_chunk = torch.cat(ids_chunk)
|
||||||
|
iids_list.append(ids_chunk)
|
||||||
|
else:
|
||||||
|
# v2
|
||||||
|
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||||
|
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
||||||
|
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
|
||||||
|
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
||||||
|
input_ids[-1].unsqueeze(0)) # PAD or EOS
|
||||||
|
ids_chunk = torch.cat(ids_chunk)
|
||||||
|
|
||||||
|
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||||
|
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||||
|
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
||||||
|
ids_chunk[-1] = self.tokenizer.eos_token_id
|
||||||
|
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||||
|
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
||||||
|
ids_chunk[1] = self.tokenizer.eos_token_id
|
||||||
|
|
||||||
|
iids_list.append(ids_chunk)
|
||||||
|
|
||||||
input_ids = torch.stack(iids_list) # 3,77
|
input_ids = torch.stack(iids_list) # 3,77
|
||||||
|
|
||||||
input_ids_list.append(input_ids)
|
input_ids_list.append(input_ids)
|
||||||
@ -192,15 +217,17 @@ def save_hypernetwork(output_file, hypernetwork):
|
|||||||
def train(args):
|
def train(args):
|
||||||
fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training
|
fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training
|
||||||
|
|
||||||
|
# その他のオプション設定を確認する
|
||||||
|
if args.v_parameterization and not args.v2:
|
||||||
|
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||||
|
if args.v2 and args.clip_skip is not None:
|
||||||
|
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||||
|
|
||||||
# モデル形式のオプション設定を確認する
|
# モデル形式のオプション設定を確認する
|
||||||
|
# v11からDiffusersから直接落としてくるのもOK(ただし認証がいるやつは未対応)、またv11からDiffusersも途中保存に対応した
|
||||||
use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
|
use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
|
||||||
if not use_stable_diffusion_format:
|
|
||||||
assert os.path.exists(
|
|
||||||
args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}"
|
|
||||||
|
|
||||||
assert not fine_tuning or (
|
|
||||||
args.save_every_n_epochs is None or use_stable_diffusion_format), "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります"
|
|
||||||
|
|
||||||
|
# 乱数系列を初期化する
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
@ -215,18 +242,22 @@ def train(args):
|
|||||||
|
|
||||||
# tokenizerを読み込む
|
# tokenizerを読み込む
|
||||||
print("prepare tokenizer")
|
print("prepare tokenizer")
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
if args.v2:
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
||||||
|
else:
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
||||||
|
|
||||||
if args.max_token_length is not None:
|
if args.max_token_length is not None:
|
||||||
print(f"update token length in tokenizer: {args.max_token_length}")
|
print(f"update token length: {args.max_token_length}")
|
||||||
|
|
||||||
# datasetを用意する
|
# datasetを用意する
|
||||||
print("prepare dataset")
|
print("prepare dataset")
|
||||||
train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size,
|
train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size,
|
||||||
tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset)
|
tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset)
|
||||||
|
|
||||||
|
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
|
||||||
|
print(f"Total images / 画像数: {train_dataset.images_count}")
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
|
|
||||||
print(f"Total images / 画像数: {train_dataset.images_count}")
|
|
||||||
train_dataset.show_buckets()
|
train_dataset.show_buckets()
|
||||||
i = 0
|
i = 0
|
||||||
for example in train_dataset:
|
for example in train_dataset:
|
||||||
@ -251,14 +282,33 @@ def train(args):
|
|||||||
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
|
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
|
mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
|
||||||
|
|
||||||
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
|
weight_dtype = torch.float32
|
||||||
|
if args.mixed_precision == "fp16":
|
||||||
|
weight_dtype = torch.float16
|
||||||
|
elif args.mixed_precision == "bf16":
|
||||||
|
weight_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
save_dtype = None
|
||||||
|
if args.save_precision == "fp16":
|
||||||
|
save_dtype = torch.float16
|
||||||
|
elif args.save_precision == "bf16":
|
||||||
|
save_dtype = torch.bfloat16
|
||||||
|
elif args.save_precision == "float":
|
||||||
|
save_dtype = torch.float32
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
print("load StableDiffusion checkpoint")
|
print("load StableDiffusion checkpoint")
|
||||||
text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path)
|
text_encoder, _, unet = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(
|
||||||
|
args.v2, args.pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
print("load Diffusers pretrained models")
|
print("load Diffusers pretrained models")
|
||||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
||||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
# , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる
|
||||||
|
text_encoder = pipe.text_encoder
|
||||||
|
unet = pipe.unet
|
||||||
|
del pipe
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
@ -279,21 +329,6 @@ def train(args):
|
|||||||
print("apply hypernetwork")
|
print("apply hypernetwork")
|
||||||
hypernetwork.apply_to_diffusers(None, text_encoder, unet)
|
hypernetwork.apply_to_diffusers(None, text_encoder, unet)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
|
||||||
weight_dtype = torch.float32
|
|
||||||
if args.mixed_precision == "fp16":
|
|
||||||
weight_dtype = torch.float16
|
|
||||||
elif args.mixed_precision == "bf16":
|
|
||||||
weight_dtype = torch.bfloat16
|
|
||||||
|
|
||||||
save_dtype = None
|
|
||||||
if args.save_precision == "fp16":
|
|
||||||
save_dtype = torch.float16
|
|
||||||
elif args.save_precision == "bf16":
|
|
||||||
save_dtype = torch.bfloat16
|
|
||||||
elif args.save_precision == "float":
|
|
||||||
save_dtype = torch.float32
|
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
training_models = []
|
training_models = []
|
||||||
if fine_tuning:
|
if fine_tuning:
|
||||||
@ -351,7 +386,7 @@ def train(args):
|
|||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
"constant", optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if fine_tuning:
|
if fine_tuning:
|
||||||
@ -384,10 +419,14 @@ def train(args):
|
|||||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
|
# v4で更新:clip_sample=Falseに
|
||||||
|
# Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる
|
||||||
|
# 既存の1.4/1.5/2.0はすべてschedulerのconfigは(クラス名を除いて)同じ
|
||||||
|
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||||
|
num_train_timesteps=1000, clip_sample=False)
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork")
|
accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork")
|
||||||
@ -400,7 +439,7 @@ def train(args):
|
|||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
with accelerator.accumulate(training_models[0]): # ここはこれでいいのか……?
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
@ -418,15 +457,29 @@ def train(args):
|
|||||||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||||
|
|
||||||
|
# bs*3, 77, 768 or 1024
|
||||||
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
||||||
|
|
||||||
if args.max_token_length is not None:
|
if args.max_token_length is not None:
|
||||||
# <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
if args.v2:
|
||||||
sts_list = [encoder_hidden_states[:, 0].unsqueeze(1)]
|
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||||
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||||
sts_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2])
|
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
||||||
sts_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
|
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||||
encoder_hidden_states = torch.cat(sts_list, dim=1)
|
if i > 0:
|
||||||
|
for j in range(len(chunk)):
|
||||||
|
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||||
|
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||||
|
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||||
|
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||||
|
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||||
|
else:
|
||||||
|
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||||
|
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||||
|
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
||||||
|
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||||
|
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||||
|
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
@ -442,7 +495,41 @@ def train(args):
|
|||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
if args.v_parameterization:
|
||||||
|
# v-parameterization training
|
||||||
|
|
||||||
|
# 11/29現在v predictionのコードがDiffusersにcommitされたがリリースされていないので独自コードを使う
|
||||||
|
# 実装の中身は同じ模様
|
||||||
|
|
||||||
|
# こうしたい:
|
||||||
|
# target = noise_scheduler.get_v(latents, noise, timesteps)
|
||||||
|
|
||||||
|
# StabilityAiのddpm.pyのコード:
|
||||||
|
# elif self.parameterization == "v":
|
||||||
|
# target = self.get_v(x_start, noise, t)
|
||||||
|
# ...
|
||||||
|
# def get_v(self, x, noise, t):
|
||||||
|
# return (
|
||||||
|
# extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
|
||||||
|
# extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
||||||
|
# )
|
||||||
|
|
||||||
|
# scheduling_ddim.pyのコード:
|
||||||
|
# elif self.config.prediction_type == "v_prediction":
|
||||||
|
# pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||||
|
# # predict V
|
||||||
|
# model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||||
|
|
||||||
|
# これでいいかな?:
|
||||||
|
alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
|
||||||
|
beta_prod_t = 1 - alpha_prod_t
|
||||||
|
alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # broadcastされないらしいのでreshape
|
||||||
|
beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
|
||||||
|
target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
|
||||||
|
else:
|
||||||
|
target = noise
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
@ -460,7 +547,7 @@ def train(args):
|
|||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
current_loss = loss.detach().item() * b_size
|
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
@ -481,14 +568,20 @@ def train(args):
|
|||||||
|
|
||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
|
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
|
||||||
print("saving check point.")
|
print("saving checkpoint.")
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
|
ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
|
||||||
|
|
||||||
if fine_tuning:
|
if fine_tuning:
|
||||||
fine_tuning_utils.save_stable_diffusion_checkpoint(
|
if use_stable_diffusion_format:
|
||||||
ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
fine_tuning_utils.save_stable_diffusion_checkpoint(
|
||||||
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
|
args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
|
||||||
|
args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
|
||||||
|
else:
|
||||||
|
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
|
||||||
|
accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype)
|
||||||
else:
|
else:
|
||||||
save_hypernetwork(ckpt_file, accelerator.unwrap_model(hypernetwork))
|
save_hypernetwork(ckpt_file, accelerator.unwrap_model(hypernetwork))
|
||||||
|
|
||||||
@ -519,16 +612,14 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
||||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||||
fine_tuning_utils.save_stable_diffusion_checkpoint(
|
fine_tuning_utils.save_stable_diffusion_checkpoint(
|
||||||
ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
|
args.v2, ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
|
||||||
else:
|
else:
|
||||||
# Create the pipeline using the trained modules and save it.
|
# Create the pipeline using the trained modules and save it.
|
||||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
|
||||||
args.pretrained_model_name_or_path,
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
unet=unet,
|
fine_tuning_utils.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
||||||
text_encoder=text_encoder,
|
args.pretrained_model_name_or_path, save_dtype)
|
||||||
)
|
|
||||||
pipeline.save_pretrained(args.output_dir)
|
|
||||||
else:
|
else:
|
||||||
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
@ -817,6 +908,10 @@ def replace_unet_cross_attn_to_xformers():
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# torch.cuda.set_per_process_memory_fraction(0.48)
|
# torch.cuda.set_per_process_memory_fraction(0.48)
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--v2", action='store_true',
|
||||||
|
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
||||||
|
parser.add_argument("--v_parameterization", action='store_true',
|
||||||
|
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
||||||
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
||||||
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
||||||
parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル")
|
parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル")
|
||||||
@ -832,7 +927,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--hypernetwork_weights", type=str, default=None,
|
parser.add_argument("--hypernetwork_weights", type=str, default=None,
|
||||||
help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)')
|
help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)')
|
||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存する(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
parser.add_argument("--resume", type=str, default=None,
|
parser.add_argument("--resume", type=str, default=None,
|
||||||
@ -857,13 +952,17 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||||
parser.add_argument("--save_precision", type=str, default=None,
|
parser.add_argument("--save_precision", type=str, default=None,
|
||||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
|
||||||
parser.add_argument("--clip_skip", type=int, default=None,
|
parser.add_argument("--clip_skip", type=int, default=None,
|
||||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||||
parser.add_argument("--debug_dataset", action="store_true",
|
parser.add_argument("--debug_dataset", action="store_true",
|
||||||
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
||||||
parser.add_argument("--logging_dir", type=str, default=None,
|
parser.add_argument("--logging_dir", type=str, default=None,
|
||||||
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
||||||
|
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
||||||
|
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
||||||
|
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
||||||
|
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
@ -2,11 +2,13 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
||||||
|
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
|
||||||
|
|
||||||
# StableDiffusionのモデルパラメータ
|
# region checkpoint変換、読み込み、書き込み ###############################
|
||||||
|
|
||||||
|
# Diffusers版StableDiffusionのモデルパラメータ
|
||||||
NUM_TRAIN_TIMESTEPS = 1000
|
NUM_TRAIN_TIMESTEPS = 1000
|
||||||
BETA_START = 0.00085
|
BETA_START = 0.00085
|
||||||
BETA_END = 0.0120
|
BETA_END = 0.0120
|
||||||
@ -29,13 +31,15 @@ VAE_PARAMS_CH = 128
|
|||||||
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
||||||
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
||||||
|
|
||||||
|
# V2
|
||||||
|
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
||||||
|
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||||
|
|
||||||
# region conversion
|
|
||||||
# checkpoint変換など ###############################
|
|
||||||
|
|
||||||
# region StableDiffusion->Diffusersの変換コード
|
# region StableDiffusion->Diffusersの変換コード
|
||||||
# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)
|
# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)
|
||||||
|
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
"""
|
"""
|
||||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||||
@ -199,7 +203,16 @@ def conv_attn_to_linear(checkpoint):
|
|||||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_unet_checkpoint(checkpoint, config):
|
def linear_transformer_to_conv(checkpoint):
    """Turn 2-D transformer projection weights into 1x1 conv kernels, in place.

    In SD v2 the 1x1 conv2d projections were replaced by linear layers, so
    their weights are 2-D; this appends two trailing singleton axes so they
    can be used as conv weights.

    Args:
        checkpoint: mutable mapping of parameter name -> tensor. Only entries
            whose last two dot-separated name segments are ``proj_in.weight``
            or ``proj_out.weight`` are considered, and only when 2-D.

    Returns:
        None. ``checkpoint`` is modified in place.
    """
    tf_keys = ("proj_in.weight", "proj_out.weight")
    # Snapshot the keys so the mapping can be reassigned while looping.
    for key in list(checkpoint.keys()):
        if ".".join(key.split(".")[-2:]) not in tf_keys:
            continue
        weight = checkpoint[key]
        if weight.ndim == 2:
            # (out, in) -> (out, in, 1, 1)
            checkpoint[key] = weight.unsqueeze(2).unsqueeze(2)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||||
"""
|
"""
|
||||||
Takes a state dict and a config, and returns a converted checkpoint.
|
Takes a state dict and a config, and returns a converted checkpoint.
|
||||||
"""
|
"""
|
||||||
@ -349,6 +362,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config):
|
|||||||
|
|
||||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||||
|
|
||||||
|
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
||||||
|
if v2:
|
||||||
|
linear_transformer_to_conv(new_checkpoint)
|
||||||
|
|
||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
@ -459,7 +476,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
return new_checkpoint
|
return new_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def create_unet_diffusers_config():
|
def create_unet_diffusers_config(v2):
|
||||||
"""
|
"""
|
||||||
Creates a config for the diffusers based on the config of the LDM model.
|
Creates a config for the diffusers based on the config of the LDM model.
|
||||||
"""
|
"""
|
||||||
@ -489,8 +506,8 @@ def create_unet_diffusers_config():
|
|||||||
up_block_types=tuple(up_block_types),
|
up_block_types=tuple(up_block_types),
|
||||||
block_out_channels=tuple(block_out_channels),
|
block_out_channels=tuple(block_out_channels),
|
||||||
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
||||||
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM,
|
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
||||||
attention_head_dim=UNET_PARAMS_NUM_HEADS,
|
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
@ -519,20 +536,82 @@ def create_vae_diffusers_config():
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint):
|
def convert_ldm_clip_checkpoint_v1(checkpoint):
    """Extract the v1 text-encoder weights from a Stable Diffusion state dict.

    Args:
        checkpoint: mapping of parameter name -> value for the whole model.

    Returns:
        A new dict holding every entry whose name starts with
        ``cond_stage_model.transformer.``, with that prefix removed.
        Loading the result into a model is left to the caller.
    """
    # Filter and slice with the same dotted prefix, so that a sibling name
    # such as "cond_stage_model.transformerX.foo" is neither accepted nor
    # cut at the wrong position.
    prefix = "cond_stage_model.transformer."
    return {
        key[len(prefix):]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix)
    }
|
||||||
|
|
||||||
text_model.load_state_dict(text_model_dict)
|
|
||||||
|
|
||||||
return text_model
|
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    """Rename SD v2 text-encoder weights to the ``text_model.*`` key layout.

    The v2 checkpoint stores the text encoder under ``cond_stage_model.model.``
    with ``resblocks`` / ``ln_*`` / ``c_fc`` / ``c_proj`` names and a fused
    ``attn.in_proj_{weight,bias}``; this function renames those entries and
    splits the fused projection into separate q/k/v entries.

    Args:
        checkpoint: mapping of parameter name -> tensor for the whole model.
        max_length: number of positions; used to build ``position_ids``.

    Returns:
        A new dict with converted keys. Entries of ``resblocks.23``,
        ``text_projection`` and ``logit_scale`` are dropped, and
        ``text_model.embeddings.position_ids`` is added.

    Raises:
        ValueError: if a ``resblocks`` entry matches none of the known
            sub-module patterns.
    """
    # The v2 layout differs from v1 in almost every name.
    def convert_key(key):
        if not key.startswith("cond_stage_model"):
            return None

        # common conversion
        key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
        key = key.replace("cond_stage_model.model.", "text_model.")

        if "resblocks" in key:
            # resblocks conversion
            key = key.replace(".resblocks.", ".layers.")
            if ".ln_" in key:
                key = key.replace(".ln_", ".layer_norm")
            elif ".mlp." in key:
                key = key.replace(".c_fc.", ".fc1.")
                key = key.replace(".c_proj.", ".fc2.")
            elif '.attn.out_proj' in key:
                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
            elif '.attn.in_proj' in key:
                key = None  # special case: split into q/k/v in the second pass
            else:
                raise ValueError(f"unexpected key in SD: {key}")
        elif '.positional_embedding' in key:
            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
        elif '.text_projection' in key:
            key = None  # apparently unused???
        elif '.logit_scale' in key:
            key = None  # apparently unused???
        elif '.token_embedding' in key:
            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
        elif '.ln_final' in key:
            key = key.replace(".ln_final", ".final_layer_norm")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}

    # First pass: plain renames.
    for key in keys:
        # remove resblocks 23
        if '.resblocks.23.' in key:
            continue
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # Second pass: convert the fused attention input projection.
    for key in keys:
        if '.resblocks.23.' in key:
            continue
        if '.resblocks' in key and '.attn.in_proj_' in key:
            # Split into three equal chunks along dim 0, used as q, k, v in that order.
            values = torch.chunk(checkpoint[key], 3, dim=0)

            key_suffix = ".weight" if "weight" in key else ".bias"
            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
            key_pfx = key_pfx.replace("_weight", "")
            key_pfx = key_pfx.replace("_bias", "")
            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

    # Add position_ids: shape (1, max_length), built directly as int64
    # instead of going through a float tensor and casting.
    new_sd["text_model.embeddings.position_ids"] = torch.arange(max_length, dtype=torch.int64).unsqueeze(0)
    return new_sd
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
@ -540,7 +619,16 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
|||||||
# region Diffusers->StableDiffusion の変換コード
|
# region Diffusers->StableDiffusion の変換コード
|
||||||
# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)
|
# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)
|
||||||
|
|
||||||
def convert_unet_state_dict(unet_state_dict):
|
def conv_transformer_to_linear(checkpoint):
    """Collapse transformer projection conv kernels to 2-D weights, in place.

    For every entry whose last two dot-separated name segments are
    ``proj_in.weight`` or ``proj_out.weight`` and whose tensor has more than
    two dimensions, keep only index 0 of the third and fourth axes.

    Args:
        checkpoint: mutable mapping of parameter name -> tensor.

    Returns:
        None. ``checkpoint`` is modified in place.
    """
    tf_keys = ("proj_in.weight", "proj_out.weight")
    # Snapshot the keys so the mapping can be reassigned while looping.
    for key in list(checkpoint.keys()):
        if ".".join(key.split(".")[-2:]) not in tf_keys:
            continue
        weight = checkpoint[key]
        if weight.ndim > 2:
            # (out, in, 1, 1) -> (out, in)
            checkpoint[key] = weight[:, :, 0, 0]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||||
unet_conversion_map = [
|
unet_conversion_map = [
|
||||||
# (stable-diffusion, HF Diffusers)
|
# (stable-diffusion, HF Diffusers)
|
||||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||||
@ -629,12 +717,16 @@ def convert_unet_state_dict(unet_state_dict):
|
|||||||
v = v.replace(hf_part, sd_part)
|
v = v.replace(hf_part, sd_part)
|
||||||
mapping[k] = v
|
mapping[k] = v
|
||||||
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
||||||
|
|
||||||
|
if v2:
|
||||||
|
conv_transformer_to_linear(new_state_dict)
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_with_conversion(ckpt_path):
|
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||||
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
||||||
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
||||||
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
||||||
@ -659,52 +751,148 @@ def load_checkpoint_with_conversion(ckpt_path):
|
|||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
def load_models_from_stable_diffusion_checkpoint(ckpt_path):
|
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    """Build the text encoder, VAE and U-Net from a Stable Diffusion checkpoint.

    Args:
        v2: when truthy, use the v2 U-Net config/conversion and the v2 text
            encoder conversion; otherwise use the v1 paths.
        ckpt_path: checkpoint location, passed to
            ``load_checkpoint_with_text_encoder_conversion``.
        dtype: optional torch dtype. When given, floating-point tensors in the
            state dict are cast to it before conversion.

    Returns:
        Tuple ``(text_model, vae, unet)``.
    """
    checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
    state_dict = checkpoint["state_dict"]
    if dtype is not None:
        for k, v in list(state_dict.items()):
            # Cast floating-point tensors only: integer tensors such as
            # position ids must not be turned into floats.
            if isinstance(v, torch.Tensor) and v.is_floating_point():
                state_dict[k] = v.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(v2)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(converted_unet_checkpoint)
    print("loading u-net:", info)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)

    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(converted_vae_checkpoint)
    print("loading vae:", info)

    # convert text_model
    if v2:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        # The v2 text encoder is built from an explicit config rather than
        # from a pretrained model name.
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=512,
            torch_dtype="float32",
            transformers_version="4.25.0.dev0",
        )
        text_model = CLIPTextModel._from_config(cfg)
    else:
        converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
    print("loading text encoder:", info)

    return text_model, vae, unet
|
||||||
|
|
||||||
|
|
||||||
def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
|
def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
    """Rename ``text_model.*`` text-encoder weights to the SD v2 key layout.

    Args:
        checkpoint: mapping of parameter name -> tensor using
            ``text_model.encoder.layers.N...`` style names.

    Returns:
        A new dict using ``transformer.resblocks.N...`` style names, with the
        separate q/k/v projections concatenated into ``attn.in_proj_weight``
        / ``attn.in_proj_bias``. ``position_ids`` entries are dropped. The
        keys carry no ``cond_stage_model.model.`` prefix.

    Raises:
        ValueError: if a ``layers`` entry matches none of the known
            sub-module patterns.
        KeyError: if a ``q_proj`` entry has no matching ``k_proj``/``v_proj``.
    """
    def convert_key(key):
        # drop position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif '.self_attn.out_proj' in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif '.self_attn.' in key:
                key = None  # special case: q/k/v are merged in the second pass
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif '.position_embedding' in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif '.token_embedding' in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif 'final_layer_norm' in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}

    # First pass: plain renames.
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # Second pass: convert the attention input projections.
    for key in keys:
        if 'layers' in key and 'q_proj' in key:
            # Concatenate the three projections along dim 0, in q, k, v order.
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value = torch.cat([checkpoint[key_q], checkpoint[key_k], checkpoint[key_v]], dim=0)

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    return new_sd
|
||||||
|
|
||||||
|
|
||||||
|
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
|
||||||
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
|
# VAEがメモリ上にないので、もう一度VAEを含めて読み込む
|
||||||
checkpoint = load_checkpoint_with_conversion(ckpt_path)
|
checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||||
state_dict = checkpoint["state_dict"]
|
state_dict = checkpoint["state_dict"]
|
||||||
|
|
||||||
|
def assign_new_sd(prefix, sd):
|
||||||
|
for k, v in sd.items():
|
||||||
|
key = prefix + k
|
||||||
|
assert key in state_dict, f"Illegal key in save SD: {key}"
|
||||||
|
if save_dtype is not None:
|
||||||
|
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
|
state_dict[key] = v
|
||||||
|
|
||||||
# Convert the UNet model
|
# Convert the UNet model
|
||||||
unet_state_dict = convert_unet_state_dict(unet.state_dict())
|
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
||||||
for k, v in unet_state_dict.items():
|
assign_new_sd("model.diffusion_model.", unet_state_dict)
|
||||||
key = "model.diffusion_model." + k
|
|
||||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
|
||||||
if save_dtype is not None:
|
|
||||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
|
||||||
state_dict[key] = v
|
|
||||||
|
|
||||||
# Convert the text encoder model
|
# Convert the text encoder model
|
||||||
text_enc_dict = text_encoder.state_dict() # 変換不要
|
if v2:
|
||||||
for k, v in text_enc_dict.items():
|
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict())
|
||||||
key = "cond_stage_model.transformer." + k
|
assign_new_sd("cond_stage_model.model.", text_enc_dict)
|
||||||
assert key in state_dict, f"Illegal key in save SD: {key}"
|
else:
|
||||||
if save_dtype is not None:
|
text_enc_dict = text_encoder.state_dict()
|
||||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
assign_new_sd("cond_stage_model.transformer.", text_enc_dict)
|
||||||
state_dict[key] = v
|
|
||||||
|
|
||||||
# Put together new checkpoint
|
# Put together new checkpoint
|
||||||
new_ckpt = {'state_dict': state_dict}
|
new_ckpt = {'state_dict': state_dict}
|
||||||
@ -718,6 +906,21 @@ def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path,
|
|||||||
new_ckpt['global_step'] = steps
|
new_ckpt['global_step'] = steps
|
||||||
|
|
||||||
torch.save(new_ckpt, output_file)
|
torch.save(new_ckpt, output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype):
    """Save the given text encoder and U-Net as a StableDiffusionPipeline.

    The VAE, scheduler and tokenizer are loaded from the ``vae``,
    ``scheduler`` and ``tokenizer`` subfolders of
    ``pretrained_model_name_or_path``.

    Args:
        v2: accepted for signature parity with the other save/load helpers;
            not used in this function.
        output_dir: destination passed to ``pipeline.save_pretrained``.
        text_encoder: text encoder to store in the pipeline.
        unet: U-Net to store in the pipeline.
        pretrained_model_name_or_path: source of the VAE, scheduler and
            tokenizer.
        save_dtype: not used in this function; the models are saved in
            whatever dtype they currently have.
    """
    pipeline = StableDiffusionPipeline(
        unet=unet,
        text_encoder=text_encoder,
        vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"),
        scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
        tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
        safety_checker=None,
        feature_extractor=None,
        # No safety checker is supplied, so state that explicitly with a bool
        # rather than passing None for a boolean option.
        requires_safety_checker=False,
    )
    pipeline.save_pretrained(output_dir)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if args.in_json is not None:
|
if args.in_json is not None:
|
||||||
|
@ -10,7 +10,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if args.in_json is not None:
|
if args.in_json is not None:
|
||||||
|
@ -36,7 +36,7 @@ def get_latents(vae, images, weight_dtype):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if os.path.exists(args.in_json):
|
if os.path.exists(args.in_json):
|
||||||
@ -49,13 +49,11 @@ def main(args):
|
|||||||
|
|
||||||
# モデル形式のオプション設定を確認する
|
# モデル形式のオプション設定を確認する
|
||||||
use_stable_diffusion_format = os.path.isfile(args.model_name_or_path)
|
use_stable_diffusion_format = os.path.isfile(args.model_name_or_path)
|
||||||
if not use_stable_diffusion_format:
|
|
||||||
assert os.path.exists(args.model_name_or_path), f"no model / モデルがありません : {args.model_name_or_path}"
|
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
print("load StableDiffusion checkpoint")
|
print("load StableDiffusion checkpoint")
|
||||||
_, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.model_name_or_path)
|
_, vae, _ = fine_tuning_utils.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_name_or_path)
|
||||||
else:
|
else:
|
||||||
print("load Diffusers pretrained models")
|
print("load Diffusers pretrained models")
|
||||||
vae = AutoencoderKL.from_pretrained(args.model_name_or_path, subfolder="vae")
|
vae = AutoencoderKL.from_pretrained(args.model_name_or_path, subfolder="vae")
|
||||||
@ -73,7 +71,8 @@ def main(args):
|
|||||||
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
|
||||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||||
|
|
||||||
bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(max_reso)
|
bucket_resos, bucket_aspect_ratios = fine_tuning_utils.make_bucket_resolutions(
|
||||||
|
max_reso, args.min_bucket_reso, args.max_bucket_reso)
|
||||||
|
|
||||||
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
|
||||||
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
|
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
|
||||||
@ -162,9 +161,13 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||||
|
parser.add_argument("--v2", action='store_true',
|
||||||
|
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||||
parser.add_argument("--max_resolution", type=str, default="512,512",
|
parser.add_argument("--max_resolution", type=str, default="512,512",
|
||||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
|
||||||
|
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||||
|
# Help text fixed: the Japanese half said "minimum" (最小) for the maximum option.
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
||||||
parser.add_argument("--mixed_precision", type=str, default="no",
|
parser.add_argument("--mixed_precision", type=str, default="no",
|
||||||
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user