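"""Pre-compute VAE latents for training images using aspect-ratio buckets.

Each image is assigned to the bucket whose aspect ratio is closest to its own,
resized and center-trimmed to the bucket resolution, encoded with the VAE, and
saved as an .npz file in the training data directory. The resolution used for
each image is recorded in the output metadata JSON.

Usage (placeholders follow the argument parser defined below):
    python <this_script>.py <train_data_dir> <in_json> <out_json> <model_name_or_path> [options]
"""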
import argparse
import os
import json
from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
import torch
from torchvision import transforms

import library.model_util as model_util
import library.train_util as train_util

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def collate_fn_remove_corrupted(batch):
    """Collate function that removes corrupted examples from the batch.

    The dataloader is expected to return 'None' for an example that could not
    be loaded; those 'None's are filtered out here.
    """
    # Filter out all the Nones (corrupted examples)
    batch = list(filter(lambda x: x is not None, batch))
    return batch
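

# Encode a batch of PIL images with the VAE and return the sampled latents as a
# NumPy float32 array on the CPU.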
def get_latents(vae, images, weight_dtype):
    img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
    img_tensors = torch.stack(img_tensors)
    img_tensors = img_tensors.to(DEVICE, weight_dtype)
    with torch.no_grad():
        latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
    return latents
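

# Build the output .npz path (without the extension) for an image key; latents
# of flipped images get a '_flip' suffix.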
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
    if is_full_path:
        base_name = os.path.splitext(os.path.basename(image_key))[0]
    else:
        base_name = image_key
    if flip:
        base_name += '_flip'
    return os.path.join(data_dir, base_name)


def main(args):
    image_paths = train_util.glob_images(args.train_data_dir)
    print(f"found {len(image_paths)} images.")

    if os.path.exists(args.in_json):
        print(f"loading existing metadata: {args.in_json}")
        with open(args.in_json, "rt", encoding='utf-8') as f:
            metadata = json.load(f)
    else:
        print(f"no metadata / メタデータファイルがありません: {args.in_json}")
        return

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
    vae.eval()
    vae.to(DEVICE, dtype=weight_dtype)

    # compute the bucket resolutions
    max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
    assert len(max_reso) == 2, \
        f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

    bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
        max_reso, args.min_bucket_reso, args.max_bucket_reso)

    # assign each image to the closest bucket and compute latents one bucket batch at a time
    bucket_aspect_ratios = np.array(bucket_aspect_ratios)
    buckets_imgs = [[] for _ in range(len(bucket_resos))]
    bucket_counts = [0 for _ in range(len(bucket_resos))]
    img_ar_errors = []
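
    # Encode and flush every bucket that has reached batch_size (or every
    # non-empty bucket when is_last is True), saving one .npz per image.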
    def process_batch(is_last):
        for j in range(len(buckets_imgs)):
            bucket = buckets_imgs[j]
            if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
                latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)

                for (image_key, _, _), latent in zip(bucket, latents):
                    npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
                    np.savez(npz_file_name, latent)

                # flip augmentation
                if args.flip_aug:
                    # without .copy() the flipped (negative-stride) view cannot be converted to a tensor
                    latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype)

                    for (image_key, _, _), latent in zip(bucket, latents):
                        npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
                        np.savez(npz_file_name, latent)

                bucket.clear()

    # optionally use a DataLoader to speed up image loading
    if args.max_data_loader_n_workers is not None:
        dataset = train_util.ImageLoadingDataset(image_paths)
        data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                           num_workers=args.max_data_loader_n_workers,
                                           collate_fn=collate_fn_remove_corrupted, drop_last=False)
    else:
        data = [[(None, ip)] for ip in image_paths]

    for data_entry in tqdm(data, smoothing=0.0):
        if data_entry[0] is None:
            continue

        img_tensor, image_path = data_entry[0]
        if img_tensor is not None:
            image = transforms.functional.to_pil_image(img_tensor)
        else:
            try:
                image = Image.open(image_path)
                if image.mode != 'RGB':
                    image = image.convert("RGB")
            except Exception as e:
                print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
                continue

        image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
        if image_key not in metadata:
            metadata[image_key] = {}

        # ideally this part would also move into the Dataset for speed, but that gets complicated
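
        # assign the image to the bucket whose aspect ratio is closest to its own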
        aspect_ratio = image.width / image.height
        ar_errors = bucket_aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        reso = bucket_resos[bucket_id]
        ar_error = ar_errors[bucket_id]
        img_ar_errors.append(abs(ar_error))

        # decide the resize target: cover the bucket resolution, then trim the overflow
        if ar_error <= 0:  # the image is wider than the bucket -> match the height
            scale = reso[1] / image.height
        else:
            scale = reso[0] / image.width

        resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))

        # print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
        #       bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])

        assert resized_size[0] == reso[0] or resized_size[1] == reso[1], \
            f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
        assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], \
            f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"

        # if the npz file(s) already exist, check the latent shape and skip the image if it matches
        if args.skip_existing:
            npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
            if args.flip_aug:
                npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")

            found = True
            for npz_file in npz_files:
                if not os.path.exists(npz_file):
                    found = False
                    break

                dat = np.load(npz_file)['arr_0']
                if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8:  # check the latent shape
                    found = False
                    break

            if found:
                continue

        # resize the image and trim it to the bucket resolution
        # PIL has no INTER_AREA, so use cv2
        image = np.array(image)
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
        if resized_size[0] > reso[0]:
            trim_size = resized_size[0] - reso[0]
            image = image[:, trim_size // 2:trim_size // 2 + reso[0]]
        elif resized_size[1] > reso[1]:
            trim_size = resized_size[1] - reso[1]
            image = image[trim_size // 2:trim_size // 2 + reso[1]]

        assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
            f"internal error, illegal trimmed size: {image.shape}, {reso}"

        # # debug
        # cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])

        # add the image to its bucket
        buckets_imgs[bucket_id].append((image_key, reso, image))
        bucket_counts[bucket_id] += 1
        metadata[image_key]['train_resolution'] = reso

        # encode any bucket that has reached batch_size
        process_batch(False)

    # process the remaining images
    process_batch(True)

    for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
        print(f"bucket {i} {reso}: {count}")
    img_ar_errors = np.array(img_ar_errors)
    print(f"mean ar error: {np.mean(img_ar_errors)}")

    # write the metadata and finish
    print(f"writing metadata: {args.out_json}")
    with open(args.out_json, "wt", encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    print("done!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
    parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
    parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
    parser.add_argument("--v2", action='store_true',
                        help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
    parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
    parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
                        help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)")
    parser.add_argument("--max_resolution", type=str, default="512,512",
                        help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
    parser.add_argument("--full_path", action="store_true",
                        help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
    parser.add_argument("--flip_aug", action="store_true",
                        help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
    parser.add_argument("--skip_existing", action="store_true",
                        help="skip images if npz already exists (both normal and flipped exist if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")

    args = parser.parse_args()
    main(args)