diff --git a/README.md b/README.md index 1673f24..b674e38 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,10 @@ Once you have created the LoRA network you can generate images via auto1111 by i ## Change history +* 2023/01/05 (v19.2): + - Add support for `--clip_skip` option + - Add missing `detect_face_rotate.py` to tools folder + - Add `gui.cmd` for easy start of GUI * 2023/01/02 (v19.2) update: - Finetune, add xformers, 8bit adam, min bucket, max bucket, batch size and flip augmentation support for dataset preparation - Finetune, add "Dataset preparation" tab to group task specific options diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 0bfed1e..01822d2 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -69,6 +69,7 @@ def save_configuration( prior_loss_weight, color_aug, flip_aug, + clip_skip, ): original_file_path = file_path @@ -123,6 +124,7 @@ def save_configuration( 'prior_loss_weight': prior_loss_weight, 'color_aug': color_aug, 'flip_aug': flip_aug, + 'clip_skip': clip_skip, } # Save the data to the selected file @@ -168,6 +170,7 @@ def open_configuration( prior_loss_weight, color_aug, flip_aug, + clip_skip, ): original_file_path = file_path @@ -223,6 +226,7 @@ def open_configuration( my_data.get('prior_loss_weight', prior_loss_weight), my_data.get('color_aug', color_aug), my_data.get('flip_aug', flip_aug), + my_data.get('clip_skip', clip_skip), ) @@ -261,6 +265,7 @@ def train_model( prior_loss_weight, color_aug, flip_aug, + clip_skip, ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -424,6 +429,8 @@ def train_model( run_cmd += f' --resume={resume}' if not float(prior_loss_weight) == 1.0: run_cmd += f' --prior_loss_weight={prior_loss_weight}' + if clip_skip > 1: + run_cmd += f' --clip_skip={int(clip_skip)}' print(run_cmd) # Run the command @@ -774,6 +781,7 @@ def dreambooth_tab( shuffle_caption = gr.Checkbox( label='Shuffle caption', value=False ) + with gr.Row(): save_state = gr.Checkbox( label='Save training state', value=False ) @@ -786,6 +794,9 @@ def dreambooth_tab( inputs=[color_aug], outputs=[cache_latent_input], ) + clip_skip = gr.Slider( + label='Clip skip', value='1', minimum=1, maximum=12, step=1 + ) with gr.Row(): resume = gr.Textbox( label='Resume from saved training state', @@ -809,209 +820,66 @@ def dreambooth_tab( gradio_dataset_balancing_tab() button_run = gr.Button('Train model') + + settings_list = [ + pretrained_model_name_or_path_input, + v2_input, + v_parameterization_input, + logging_dir_input, + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + max_resolution_input, + learning_rate_input, + lr_scheduler_input, + lr_warmup_input, + train_batch_size_input, + epoch_input, + save_every_n_epochs_input, + mixed_precision_input, + save_precision_input, + seed_input, + num_cpu_threads_per_process_input, + cache_latent_input, + caption_extention_input, + enable_bucket_input, + gradient_checkpointing_input, + full_fp16_input, + no_token_padding_input, + stop_text_encoder_training_input, + use_8bit_adam_input, + xformers_input, + save_model_as_dropdown, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + ] button_open_config.click( open_configuration, - inputs=[ - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - ], - outputs=[ - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - ], + inputs=[config_file_name] + settings_list, + outputs=[config_file_name] + settings_list, ) button_save_config.click( save_configuration, - inputs=[ - dummy_db_false, - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - ], + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name], ) button_save_as_config.click( save_configuration, - inputs=[ - dummy_db_true, - config_file_name, - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - ], + inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name], ) button_run.click( train_model, - inputs=[ - pretrained_model_name_or_path_input, - v2_input, - v_parameterization_input, - logging_dir_input, - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - max_resolution_input, - learning_rate_input, - lr_scheduler_input, - lr_warmup_input, - train_batch_size_input, - epoch_input, - save_every_n_epochs_input, - mixed_precision_input, - save_precision_input, - seed_input, - num_cpu_threads_per_process_input, - cache_latent_input, - caption_extention_input, - enable_bucket_input, - gradient_checkpointing_input, - full_fp16_input, - no_token_padding_input, - stop_text_encoder_training_input, - use_8bit_adam_input, - xformers_input, - save_model_as_dropdown, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - ], + inputs=settings_list, ) return ( diff --git a/finetune_gui.py b/finetune_gui.py index 67ad39b..c7e279a 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -56,6 +56,7 @@ def save_configuration( caption_extension, use_8bit_adam, xformers, + clip_skip, ): original_file_path = file_path @@ -109,6 +110,7 @@ def save_configuration( 'caption_extension': caption_extension, 'use_8bit_adam': use_8bit_adam, 'xformers': xformers, + 'clip_skip': clip_skip, } # Save the data to the selected file @@ -153,6 +155,7 @@ def open_config_file( caption_extension, use_8bit_adam, xformers, + clip_skip, ): original_file_path = file_path file_path = get_file_path(file_path) @@ -206,6 +209,7 @@ def open_config_file( my_data.get('caption_extension', caption_extension), my_data.get('use_8bit_adam', use_8bit_adam), my_data.get('xformers', xformers), + my_data.get('clip_skip', clip_skip), ) @@ -243,6 +247,7 @@ def train_model( caption_extension, use_8bit_adam, xformers, + clip_skip, ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -358,6 +363,8 @@ def train_model( run_cmd += f' --save_precision={save_precision}' if not save_model_as == 'same as source model': run_cmd += f' --save_model_as={save_model_as}' + if clip_skip > 1: + run_cmd += f' --clip_skip={int(clip_skip)}' print(run_cmd) # Run the command @@ -688,6 +695,9 @@ def finetune_tab(): with gr.Row(): use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=True) xformers = gr.Checkbox(label='Use xformers', value=True) + clip_skip = gr.Slider( + label='Clip skip', value='1', minimum=1, maximum=12, step=1 + ) with gr.Box(): with gr.Row(): create_caption = gr.Checkbox( @@ -733,6 +743,7 @@ def finetune_tab(): caption_extention_input, use_8bit_adam, xformers, + clip_skip, ] button_run.click(train_model, inputs=settings_list) diff --git a/gui.cmd b/gui.cmd new file mode 100644 index 0000000..379ff8d --- /dev/null +++ b/gui.cmd @@ -0,0 +1 @@ +.\venv\Scripts\python.exe kohya_gui.py \ No newline at end of file diff --git a/lora_gui.py b/lora_gui.py index e33c7da..c95bac8 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -72,6 +72,7 @@ def save_configuration( lora_network_weights, color_aug, flip_aug, + clip_skip, ): original_file_path = file_path @@ -129,6 +130,7 @@ def save_configuration( 'lora_network_weights': lora_network_weights, 'color_aug': color_aug, 'flip_aug': flip_aug, + 'clip_skip': clip_skip, } # Save the data to the selected file @@ -177,6 +179,7 @@ def open_configuration( lora_network_weights, color_aug, flip_aug, + clip_skip, ): original_file_path = file_path @@ -235,6 +238,7 @@ def open_configuration( my_data.get('lora_network_weights', lora_network_weights), my_data.get('color_aug', color_aug), my_data.get('flip_aug', flip_aug), + my_data.get('clip_skip', clip_skip), ) @@ -276,6 +280,7 @@ def train_model( lora_network_weights, color_aug, flip_aug, + clip_skip, ): def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required @@ -361,13 +366,13 @@ def train_model( # Print the result # print(f"{total_steps} total steps") - if reg_data_dir == '': - reg_factor = 1 - else: - print( - 'Regularisation images are used... Will double the number of steps required...' - ) - reg_factor = 2 + # if reg_data_dir == '': + # reg_factor = 1 + # else: + # print( + # 'Regularisation images are used... Will double the number of steps required...' + # ) + # reg_factor = 2 # calculate max_train_steps max_train_steps = int( @@ -375,7 +380,7 @@ def train_model( float(total_steps) / int(train_batch_size) * int(epoch) - * int(reg_factor) + # * int(reg_factor) ) ) print(f'max_train_steps = {max_train_steps}') @@ -467,6 +472,8 @@ def train_model( run_cmd += f' --network_dim={network_dim}' if not lora_network_weights == '': run_cmd += f' --network_weights={lora_network_weights}' + if int(clip_skip) > 1: + run_cmd += f' --clip_skip={int(clip_skip)}' print(run_cmd) # Run the command @@ -860,6 +867,7 @@ def lora_tab( shuffle_caption = gr.Checkbox( label='Shuffle caption', value=False ) + with gr.Row(): save_state = gr.Checkbox( label='Save training state', value=False ) @@ -872,6 +880,9 @@ def lora_tab( inputs=[color_aug], outputs=[cache_latent_input], ) + clip_skip = gr.Slider( + label='Clip skip', value='1', minimum=1, maximum=12, step=1 + ) with gr.Row(): resume = gr.Textbox( label='Resume from saved training state', @@ -935,6 +946,7 @@ def lora_tab( lora_network_weights, color_aug, flip_aug, + clip_skip, ] button_open_config.click( diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py new file mode 100644 index 0000000..ef6d188 --- /dev/null +++ b/tools/detect_face_rotate.py @@ -0,0 +1,239 @@ +# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします +# (c) 2022 Kohya S. @kohya_ss + +# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す + +# v2: extract max face if multiple faces are found +# v3: add crop_ratio option +# v4: add multiple faces extraction and min/max size + +import argparse +import math +import cv2 +import glob +import os +from anime_face_detector import create_detector +from tqdm import tqdm +import numpy as np + +KP_REYE = 11 +KP_LEYE = 19 + +SCORE_THRES = 0.90 + + +def detect_faces(detector, image, min_size): + preds = detector(image) # bgr + # print(len(preds)) + + faces = [] + for pred in preds: + bb = pred['bbox'] + score = bb[-1] + if score < SCORE_THRES: + continue + + left, top, right, bottom = bb[:4] + cx = int((left + right) / 2) + cy = int((top + bottom) / 2) + fw = int(right - left) + fh = int(bottom - top) + + lex, ley = pred['keypoints'][KP_LEYE, 0:2] + rex, rey = pred['keypoints'][KP_REYE, 0:2] + angle = math.atan2(ley - rey, lex - rex) + angle = angle / math.pi * 180 + + faces.append((cx, cy, fw, fh, angle)) + + faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順 + return faces + + +def rotate_image(image, angle, cx, cy): + h, w = image.shape[0:2] + rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + + # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化 + # nh = max(h, int(w * math.sin(angle))) + # nw = max(w, int(h * math.sin(angle))) + # if nh > h or nw > w: + # pad_y = nh - h + # pad_t = pad_y // 2 + # pad_x = nw - w + # pad_l = pad_x // 2 + # m = np.array([[0, 0, pad_l], + # [0, 0, pad_t]]) + # rot_mat = rot_mat + m + # h, w = nh, nw + # cx += pad_l + # cy += pad_t + + result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + return result, cx, cy + + +def process(args): + assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません" + assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" + + # アニメ顔検出モデルを読み込む + print("loading face detector.") + detector = create_detector('yolov3') + + # cropの引数を解析する + if args.crop_size is None: + crop_width = crop_height = None + else: + tokens = args.crop_size.split(',') + assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください" + crop_width, crop_height = [int(t) for t in tokens] + + if args.crop_ratio is None: + crop_h_ratio = crop_v_ratio = None + else: + tokens = args.crop_ratio.split(',') + assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください" + crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] + + # 画像を処理する + print("processing.") + output_extension = ".png" + + os.makedirs(args.dst_dir, exist_ok=True) + paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ + glob.glob(os.path.join(args.src_dir, "*.webp")) + for path in tqdm(paths): + basename = os.path.splitext(os.path.basename(path))[0] + + # image = cv2.imread(path) # 日本語ファイル名でエラーになる + image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + if image.shape[2] == 4: + print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") + image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい + + h, w = image.shape[:2] + + faces = detect_faces(detector, image, args.multiple_faces) + for i, face in enumerate(faces): + cx, cy, fw, fh, angle = face + face_size = max(fw, fh) + if args.min_size is not None and face_size < args.min_size: + continue + if args.max_size is not None and face_size >= args.max_size: + continue + face_suffix = f"_{i+1:02d}" if args.multiple_faces else "" + + # オプション指定があれば回転する + face_img = image + if args.rotate: + face_img, cx, cy = rotate_image(face_img, angle, cx, cy) + + # オプション指定があれば顔を中心に切り出す + if crop_width is not None or crop_h_ratio is not None: + cur_crop_width, cur_crop_height = crop_width, crop_height + if crop_h_ratio is not None: + cur_crop_width = int(face_size * crop_h_ratio + .5) + cur_crop_height = int(face_size * crop_v_ratio + .5) + + # リサイズを必要なら行う + scale = 1.0 + if args.resize_face_size is not None: + # 顔サイズを基準にリサイズする + scale = args.resize_face_size / face_size + if scale < cur_crop_width / w: + print( + f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") + scale = cur_crop_width / w + if scale < cur_crop_height / h: + print( + f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") + scale = cur_crop_height / h + elif crop_h_ratio is not None: + # 倍率指定の時にはリサイズしない + pass + else: + # 切り出しサイズ指定あり + if w < cur_crop_width: + print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") + scale = cur_crop_width / w + if h < cur_crop_height: + print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") + scale = cur_crop_height / h + if args.resize_fit: + scale = max(cur_crop_width / w, cur_crop_height / h) + + if scale != 1.0: + w = int(w * scale + .5) + h = int(h * scale + .5) + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + cx = int(cx * scale + .5) + cy = int(cy * scale + .5) + fw = int(fw * scale + .5) + fh = int(fh * scale + .5) + + cur_crop_width = min(cur_crop_width, face_img.shape[1]) + cur_crop_height = min(cur_crop_height, face_img.shape[0]) + + x = cx - cur_crop_width // 2 + cx = cur_crop_width // 2 + if x < 0: + cx = cx + x + x = 0 + elif x + cur_crop_width > w: + cx = cx + (x + cur_crop_width - w) + x = w - cur_crop_width + face_img = face_img[:, x:x+cur_crop_width] + + y = cy - cur_crop_height // 2 + cy = cur_crop_height // 2 + if y < 0: + cy = cy + y + y = 0 + elif y + cur_crop_height > h: + cy = cy + (y + cur_crop_height - h) + y = h - cur_crop_height + face_img = face_img[y:y + cur_crop_height] + + # # debug + # print(path, cx, cy, angle) + # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) + # cv2.imshow("image", crp) + # if cv2.waitKey() == 27: + # break + # cv2.destroyAllWindows() + + # debug + if args.debug: + cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) + + _, buf = cv2.imencode(output_extension, face_img) + with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: + buf.tofile(f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") + parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") + parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") + parser.add_argument("--resize_fit", action="store_true", + help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする") + parser.add_argument("--resize_face_size", type=int, default=None, + help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする") + parser.add_argument("--crop_size", type=str, default=None, + help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す") + parser.add_argument("--crop_ratio", type=str, default=None, + help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す") + parser.add_argument("--min_size", type=int, default=None, + help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)") + parser.add_argument("--max_size", type=int, default=None, + help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)") + parser.add_argument("--multiple_faces", action="store_true", + help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") + parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") + args = parser.parse_args() + + process(args) \ No newline at end of file