"""Interrogate a LoRA: rank vocabulary tokens by how much the LoRA changes their Text Encoder output."""
import argparse

import torch
from tqdm import tqdm
from transformers import CLIPTokenizer

import library.model_util as model_util
import lora

TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this repo

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _has_text_encoder_weights(network) -> bool:
    """Return True if any key in ``network.weights_sd`` contains ``'lora_te'``.

    Keys containing ``'lora_te'`` are treated as Text Encoder LoRA modules.
    Ideally this check would live on the LoRA side (the ``lora`` module).
    """
    return any('lora_te' in key for key in network.weights_sd.keys())


def _load_tokenizer(v2: bool) -> CLIPTokenizer:
    """Load the CLIP tokenizer: the SD2 repo's tokenizer when ``v2`` is set, else ``TOKENIZER_PATH``."""
    if v2:
        return CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
    return CLIPTokenizer.from_pretrained(TOKENIZER_PATH)


def _get_all_embeddings(text_encoder, tokenizer, token_id_start: int, token_id_end: int,
                        batch_size: int, clip_skip) -> torch.Tensor:
    """Encode every token id in ``[token_id_start, token_id_end]`` and return the outputs on CPU.

    Each token is encoded as the three-token sequence ``[bos, token, eos]``. The
    original author noted that encoding the bare token gave worse results. They
    also noted that comparing all positions, including bos/eos, showed larger
    differences than comparing only the token position.

    Args:
        text_encoder: the Text Encoder, already on ``DEVICE``.
        tokenizer: tokenizer providing ``bos_token_id`` / ``eos_token_id``.
        token_id_start: first token id to encode (inclusive).
        token_id_end: last token id to encode (inclusive).
        batch_size: number of token ids encoded per forward pass (must be >= 1).
        clip_skip: ``None`` to use the encoder's normal output. Otherwise use the
            ``clip_skip``-th hidden state from the end, followed by the final layer norm.

    Returns:
        CPU tensor whose first dimension indexes token ids in order, starting at
        ``token_id_start``. Presumably (num_tokens, 3, hidden_dim); TODO confirm.
    """
    chunks = []
    with torch.no_grad():
        for batch_start in tqdm(range(token_id_start, token_id_end + 1, batch_size)):
            batch_end = min(token_id_end + 1, batch_start + batch_size)
            batch = torch.tensor(
                [[tokenizer.bos_token_id, tid, tokenizer.eos_token_id] for tid in range(batch_start, batch_end)]
            ).to(DEVICE)

            if clip_skip is None:
                encoder_hidden_states = text_encoder(batch)[0]
            else:
                # clip skip: intermediate hidden states are taken before final_layer_norm,
                # so the norm is applied explicitly to the selected layer's output.
                enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
                encoder_hidden_states = enc_out['hidden_states'][-clip_skip]
                encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

            chunks.append(encoder_hidden_states.to("cpu"))
    return torch.cat(chunks, dim=0)


def interrogate(args):
    """Print the 100 tokens whose Text Encoder output the LoRA changes the most.

    Args:
        args: parsed CLI namespace with ``v2``, ``sd_model``, ``model``,
            ``batch_size`` and ``clip_skip`` (see ``setup_parser``).

    Returns early, after printing a message, if the LoRA has no Text Encoder modules.
    """
    print(f"loading SD model: {args.sd_model}")
    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)

    print(f"loading LoRA: {args.model}")
    network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)

    if not _has_text_encoder_weights(network):
        print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
        return

    del vae

    print("loading tokenizer")
    tokenizer = _load_tokenizer(args.v2)

    text_encoder.to(DEVICE)
    text_encoder.eval()
    unet.to(DEVICE)
    unet.eval()  # the U-Net is never called in this script, so this is not strictly needed

    # Probe every token id from 0 up to the largest special-token id.
    # NOTE(review): assumes the special tokens sit at the end of the vocabulary
    # (true for the CLIP tokenizers used here); len(tokenizer) - 1 would be more direct.
    token_id_start = 0
    token_id_end = max(tokenizer.all_special_ids)
    print(f"interrogate tokens are: {token_id_start} to {token_id_end}")

    print("get original text encoder embeddings.")
    orig_embs = _get_all_embeddings(text_encoder, tokenizer, token_id_start, token_id_end,
                                    args.batch_size, args.clip_skip)

    network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
    network.to(DEVICE)
    network.eval()

    print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
    print("get text encoder embeddings with lora.")
    lora_embs = _get_all_embeddings(text_encoder, tokenizer, token_id_start, token_id_end,
                                    args.batch_size, args.clip_skip)

    # Compare with a simple per-token mean absolute difference. Per the original
    # author's note, a cosine-similarity score did not detect changes well.
    # The whole comparison is vectorized instead of looping over ~49k tokens in Python.
    print("comparing...")
    per_token_diff = (orig_embs - lora_embs).abs_().flatten(start_dim=1).mean(dim=1).tolist()
    diffs = {token_id_start + i: diff for i, diff in enumerate(per_token_diff)}
    diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])

    print("top 100:")
    for i, (token, diff) in enumerate(diffs_sorted[:100]):
        string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
        print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")


def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--v2", action='store_true', help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
    parser.add_argument("--sd_model", type=str, default=None,
                        help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
    parser.add_argument("--model", type=str, default=None,
                        help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
    parser.add_argument("--clip_skip", type=int, default=None,
                        help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")

    return parser


if __name__ == '__main__':
    parser = setup_parser()

    args = parser.parse_args()
    # Validate up front: otherwise a missing path fails deep inside model loading,
    # batch_size < 1 breaks range(), and clip_skip < 1 silently selects the wrong
    # hidden state (hidden_states[-0] is the embedding layer).
    if args.sd_model is None or args.model is None:
        parser.error("--sd_model and --model are required")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.clip_skip is not None and args.clip_skip < 1:
        parser.error("--clip_skip must be >= 1")
    interrogate(args)