129 lines
5.5 KiB
Python
129 lines
5.5 KiB
Python
|
||
|
||
from tqdm import tqdm
|
||
from library import model_util
|
||
import argparse
|
||
from transformers import CLIPTokenizer
|
||
import torch
|
||
|
||
import library.model_util as model_util
|
||
import lora
|
||
|
||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||
|
||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||
|
||
|
||
def interrogate(args):
|
||
# いろいろ準備する
|
||
print(f"loading SD model: {args.sd_model}")
|
||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||
|
||
print(f"loading LoRA: {args.model}")
|
||
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||
|
||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||
has_te_weight = False
|
||
for key in network.weights_sd.keys():
|
||
if 'lora_te' in key:
|
||
has_te_weight = True
|
||
break
|
||
if not has_te_weight:
|
||
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
||
return
|
||
del vae
|
||
|
||
print("loading tokenizer")
|
||
if args.v2:
|
||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
||
else:
|
||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
||
|
||
text_encoder.to(DEVICE)
|
||
text_encoder.eval()
|
||
unet.to(DEVICE)
|
||
unet.eval() # U-Netは呼び出さないので不要だけど
|
||
|
||
# トークンをひとつひとつ当たっていく
|
||
token_id_start = 0
|
||
token_id_end = max(tokenizer.all_special_ids)
|
||
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
||
|
||
def get_all_embeddings(text_encoder):
|
||
embs = []
|
||
with torch.no_grad():
|
||
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
|
||
batch = []
|
||
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
|
||
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
|
||
# tokens = [tid] # こちらは結果がいまひとつ
|
||
batch.append(tokens)
|
||
|
||
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
|
||
# clip skip対応
|
||
batch = torch.tensor(batch).to(DEVICE)
|
||
if args.clip_skip is None:
|
||
encoder_hidden_states = text_encoder(batch)[0]
|
||
else:
|
||
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
|
||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||
encoder_hidden_states = encoder_hidden_states.to("cpu")
|
||
|
||
embs.extend(encoder_hidden_states)
|
||
return torch.stack(embs)
|
||
|
||
print("get original text encoder embeddings.")
|
||
orig_embs = get_all_embeddings(text_encoder)
|
||
|
||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||
network.to(DEVICE)
|
||
network.eval()
|
||
|
||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||
print("get text encoder embeddings with lora.")
|
||
lora_embs = get_all_embeddings(text_encoder)
|
||
|
||
# 比べる:とりあえず単純に差分の絶対値で
|
||
print("comparing...")
|
||
diffs = {}
|
||
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
||
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
||
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
|
||
diff = float(diff.detach().to('cpu').numpy())
|
||
diffs[token_id_start + i] = diff
|
||
|
||
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
|
||
|
||
# 結果を表示する
|
||
print("top 100:")
|
||
for i, (token, diff) in enumerate(diffs_sorted[:100]):
|
||
# if diff < 1e-6:
|
||
# break
|
||
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
|
||
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
||
|
||
|
||
def setup_parser() -> argparse.ArgumentParser:
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument("--v2", action='store_true',
|
||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||
parser.add_argument("--sd_model", type=str, default=None,
|
||
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
|
||
parser.add_argument("--model", type=str, default=None,
|
||
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
||
parser.add_argument("--batch_size", type=int, default=16,
|
||
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
|
||
parser.add_argument("--clip_skip", type=int, default=None,
|
||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||
|
||
return parser
|
||
|
||
|
||
if __name__ == '__main__':
|
||
parser = setup_parser()
|
||
|
||
args = parser.parse_args()
|
||
interrogate(args)
|