import argparse

import torch
from tqdm import tqdm
from transformers import CLIPTokenizer

import library.model_util as model_util
import lora

TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is taken from this model

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

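
# This script estimates which tokens a LoRA's Text Encoder modules affect:
# every token id is embedded once with the original Text Encoder and once
# with the LoRA applied, and the tokens are ranked by how much the two
# embeddings differ.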
def interrogate(args):
    # prepare the models
    print(f"loading SD model: {args.sd_model}")
    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
    print(f"loading LoRA: {args.model}")
    network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)

    # check whether there are weights for the Text Encoder: ideally this check should live on the lora side
    has_te_weight = False
    for key in network.weights_sd.keys():
        if 'lora_te' in key:
            has_te_weight = True
            break
    if not has_te_weight:
        print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
        return

    del vae

    print("loading tokenizer")
    if args.v2:
        tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
    else:
        tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)

    text_encoder.to(DEVICE)
    text_encoder.eval()
    unet.to(DEVICE)
    unet.eval()  # the U-Net is never called, so this is not strictly necessary

    # examine the tokens one by one
    token_id_start = 0
    token_id_end = max(tokenizer.all_special_ids)
    print(f"interrogate tokens are: {token_id_start} to {token_id_end}")

    def get_all_embeddings(text_encoder):
        embs = []
        with torch.no_grad():
            for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
                batch = []
                for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
                    tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
                    # tokens = [tid]  # this gives somewhat poorer results
                    batch.append(tokens)

                # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu")  # including bos/eos seems to bring out the differences [:, 1]

                # support clip skip
                batch = torch.tensor(batch).to(DEVICE)
                if args.clip_skip is None:
                    encoder_hidden_states = text_encoder(batch)[0]
                else:
                    enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
                    encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
                    encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
                encoder_hidden_states = encoder_hidden_states.to("cpu")

                embs.extend(encoder_hidden_states)
        return torch.stack(embs)
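
    # Each entry is a [3, hidden] tensor for [bos, token, eos]; with roughly 49k token
    # ids and a hidden size of 768 (v1) / 1024 (v2), each stacked snapshot is on the
    # order of 0.5 GB in fp32, hence both snapshots are kept on the CPU.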
print ( " get original text encoder embeddings. " )
orig_embs = get_all_embeddings ( text_encoder )
network . apply_to ( text_encoder , unet , True , len ( network . unet_loras ) > 0 )
network . to ( DEVICE )
network . eval ( )
print ( " You can ignore warning messages start with ' _IncompatibleKeys ' (LoRA model does not have alpha because trained by older script) / ' _IncompatibleKeys ' の警告は無視して構いません( 以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません) " )
print ( " get text encoder embeddings with lora. " )
lora_embs = get_all_embeddings ( text_encoder )
    # compare: for now just use the mean absolute difference
    print("comparing...")
    diffs = {}
    for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
        diff = torch.mean(torch.abs(orig_emb - lora_emb))
        # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1))  # does not detect the changes well
        diff = float(diff.detach().to('cpu').numpy())
        diffs[token_id_start + i] = diff

    diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])  # largest difference first

    # show the results
    print("top 100:")
    for i, (token, diff) in enumerate(diffs_sorted[:100]):
        # if diff < 1e-6:
        #     break
        string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
        print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    parser.add_argument("--v2", action='store_true',
                        help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
    parser.add_argument("--sd_model", type=str, default=None,
                        help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
    parser.add_argument("--model", type=str, default=None,
                        help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
    parser.add_argument("--clip_skip", type=int, default=None,
                        help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")

    return parser


if __name__ == '__main__':
    parser = setup_parser()

    args = parser.parse_args()
    interrogate(args)
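
# Usage sketch (the script filename and the two file paths below are illustrative
# assumptions, not taken from this file):
#   python lora_interrogator.py --sd_model model.ckpt --model lora.safetensors --batch_size 16 --clip_skip 2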